{-# LANGUAGE LambdaCase
           , CPP
           , TemplateHaskell           
           #-}
module Data.StructuralTraversal.TH where
  
import Control.Applicative
import Control.Monad
import Data.List (nub)
import Data.Maybe

import Language.Haskell.TH

import Data.StructuralTraversal.Class
  
-- | Derive a 'StructuralTraversable' instance for the named data type or
-- newtype. The instance traverses on the type's last type parameter.
--
-- Fails at compile time with "Unsupported data type" for other kinds of
-- type constructors (e.g. type synonyms or families), and with
-- "Expected the name of a data type or newtype" when the name does not
-- refer to a type constructor at all.
deriveStructTrav :: Name -> Q [Dec]
deriveStructTrav nm = do
  info <- reify nm
  case info of
    TyConI (DataD _ tyConName typArgs _ dataCons _) ->
      createInstance tyConName typArgs dataCons
    TyConI (NewtypeD _ tyConName typArgs _ dataCon _) ->
      -- a newtype is handled as a data type with exactly one constructor
      createInstance tyConName typArgs [dataCon]
    TyConI _ ->
      fail "Unsupported data type"
    _ ->
      fail "Expected the name of a data type or newtype"
  
-- | Build the 'StructuralTraversable' instance for a type constructor.
--
-- The instance head is the type constructor applied to all of its type
-- parameters except the last one, which is the parameter traversed on.
-- 'traverseUp' and 'traverseDown' each get one clause per data constructor.
-- A clause binds @desc@, @asc@ and @f@ together with the constructor's
-- fields. It then rebuilds the value with the applicative operators, or with
-- 'pure' for a nullary constructor:
--
-- * a field whose type is the traversed variable becomes @f field@;
-- * a field whose type is an application ending in the traversed variable
--   becomes a (possibly nested) call of the method being defined, and a
--   'StructuralTraversable' constraint is added for every type-variable
--   head met along the way;
-- * every other field becomes @pure field@.
--
-- Fails via 'fail' in three cases: when the type has no type parameters;
-- when a constructor is not a plain, record or infix constructor; or when
-- the last parameter is not of kind @*@. The kind check is only run once a
-- field's type reaches a type variable.
createInstance :: Name -> [TyVarBndr] -> [Con] -> Q [Dec]
createInstance tyConName typArgs dataCons = do
    -- A type without parameters has nothing to traverse on. Report it up
    -- front; otherwise the instance head would crash on 'init []' whenever
    -- no field happens to mention a type variable.
    when (null typArgs) $
      fail ("The kind of type " ++ show tyConName ++ " is *")
    (upClauses, preds) <- unzip <$> mapM (createClause upName) dataCons
    (downClauses, _) <- unzip <$> mapM (createClause downName) dataCons
    return
      [ InstanceD Nothing
          -- every field using a given type variable produces the same
          -- constraint; keep a single copy of each
          (nub (concat preds))
          (AppT (ConT className) instanceHead)
          (methods [(upName, upClauses), (downName, downClauses)])
      ]
  where
    -- The type the instance is declared for: the type constructor applied
    -- to every type parameter except the traversed (last) one. 'init' is
    -- safe because empty parameter lists are rejected above.
    instanceHead :: Type
    instanceHead = foldl AppT (ConT tyConName) (map getTypVarTyp (init typArgs))

    -- Method definitions of the instance. A type without constructors has
    -- no clauses, and GHC rejects a 'FunD' without equations, so such
    -- methods are left out.
    methods :: [(Name, [Clause])] -> [Dec]
    methods defs = [ FunD n cs | (n, cs) <- defs, not (null cs) ]

    -- Gets the variable that is traversed on: the last type parameter,
    -- which must be of kind *.
    varToTraverseOn :: Q Name
    varToTraverseOn = case reverse typArgs of
      (PlainTV v : _)        -> return v
      (KindedTV v StarT : _) -> return v
      (KindedTV _ _ : _)     -> fail "The kind of the last type parameter is not *"
      []                     -> fail ("The kind of type " ++ show tyConName ++ " is *")

    -- Creates a clause for a constructor; the needed context is also generated.
    createClause :: Name -> Con -> Q (Clause, [Pred])
    createClause funN (RecC conName conArgs)
      = createClause' funN conName (map (\(_, _, t) -> t) conArgs)
    createClause funN (NormalC conName conArgs)
      = createClause' funN conName (map snd conArgs)
    createClause funN (InfixC conArg1 conName conArg2)
      = createClause' funN conName [snd conArg1, snd conArg2]
    createClause _ con
      -- existential (ForallC) and GADT-syntax constructors are not handled
      = fail ("Only plain, record and infix constructors are supported, not: "
                ++ pprint con)

    -- Binds fresh names for the fields and builds the clause
    -- @funN desc asc f (Con p1 .. pn) = ...@ plus the context it needs.
    createClause' :: Name -> Name -> [Type] -> Q (Clause, [Pred])
    createClause' funN conName argTypes = do
      boundNames <- replicateM (length argTypes) (newName "p")
      (handleParams, ctx) <- unzip <$> zipWithM (processParam funN) boundNames argTypes
      return ( Clause [ VarP desc, VarP asc, VarP f
                      , ConP conName (map VarP boundNames) ]
                      (NormalB (createExpr conName handleParams))
                      []
             , concat ctx )

    -- Creates the body of a clause from the per-field expressions:
    -- @pure Con@ for no fields, otherwise @Con <$> e1 <*> ... <*> en@.
    createExpr :: Name -> [Exp] -> Exp
    createExpr ctrName [] = AppE applPure (ConE ctrName)
    createExpr ctrName (p1 : ps)
      = foldl (\acc e -> InfixE (Just acc) applStar (Just e))
              (InfixE (Just (ConE ctrName)) applDollar (Just p1))
              ps

    -- Name quotes give global names, so the generated code does not depend
    -- on what the splicing module happens to import.
    applStar, applDollar, applPure :: Exp
    applStar   = VarE '(<*>)
    applDollar = VarE '(<$>)
    applPure   = VarE 'pure

    className, upName, downName, desc, asc, f :: Name
    className = ''StructuralTraversable
    upName    = 'traverseUp
    downName  = 'traverseDown
    desc      = mkName "desc"
    asc       = mkName "asc"
    f         = mkName "f"

    -- Creates the expression and the context for one field.
    processParam :: Name -> Name -> Type -> Q (Exp, [Pred])
    processParam _ name (VarT v) = do
      travV <- varToTraverseOn
      return $ if v == travV
        then (AppE (VarE f) (VarE name), [])
        else (AppE applPure (VarE name), [])
    processParam funN name (AppT tf ta) = do
      expr <- createExprForHighKind' funN name (VarE f) ta
      return $ case expr of
        Just (e, ctx) -> (e, addConstraintFor tf ctx)
        Nothing       -> (AppE applPure (VarE name), [])
    processParam _ name _
      = return (AppE applPure (VarE name), [])

    -- Creates an expression and a context for a higher-kinded field. Each
    -- application layer wraps the function @g@ in one more call of @funN@.
    -- The result is Nothing unless the innermost argument is the traversed
    -- variable.
    createExprForHighKind' :: Name -> Name -> Exp -> Type -> Q (Maybe (Exp, [Pred]))
    createExprForHighKind' funN name g (AppT tf ta) = do
      res <- createExprForHighKind' funN name (applExpr funN g) ta
      return $ fmap (\(e, ctx) -> (e, addConstraintFor tf ctx)) res
    createExprForHighKind' funN name g (VarT v) = do
      travV <- varToTraverseOn
      return $ if v == travV
        then Just (AppE (applExpr funN g) (VarE name), [])
        else Nothing
    createExprForHighKind' _ _ _ _
      = return Nothing

    -- Adds a 'StructuralTraversable' constraint when the applied type is a
    -- type variable; other heads contribute no context.
    addConstraintFor :: Type -> [Pred] -> [Pred]
    addConstraintFor tf ctx
      | isTypVar tf = createConstraint className tf : ctx
      | otherwise   = ctx

    -- @funN desc asc g@
    applExpr :: Name -> Exp -> Exp
    applExpr funN g = foldl AppE (VarE funN) [VarE desc, VarE asc, g]
                       
-- | Build the class constraint @name typ@, e.g. @StructuralTraversable t@.
--
-- From GHC 7.10 on, a 'Pred' is an ordinary 'Type', so the constraint is
-- just the class applied to the type. Older compilers need the dedicated
-- 'ClassP' constructor.
createConstraint :: Name -> Type -> Pred
#if __GLASGOW_HASKELL__ >= 710
createConstraint name = AppT (ConT name)
#else
createConstraint name typ = ClassP name [typ]
#endif
                       
               
-- | 'True' exactly when the type is a bare type variable ('VarT').
isTypVar :: Type -> Bool
isTypVar t = case t of
  VarT _ -> True
  _      -> False
  
-- | The type variable bound by a binder, with any kind annotation dropped.
getTypVarTyp :: TyVarBndr -> Type
getTypVarTyp bndr = case bndr of
  KindedTV n _ -> VarT n
  PlainTV n    -> VarT n
  
  
-- | Debugging helper: run a declaration splice, pretty-print the generated
-- declarations to stdout (via 'runIO', i.e. while the splice runs), and
-- return them unchanged.
thExamine :: Q [Dec] -> Q [Dec]
thExamine decl = do
  decs <- decl
  runIO . putStrLn $ pprint decs
  pure decs