{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskell #-}
-- | Template Haskell generator for 'StructuralTraversable' instances.
module Data.StructuralTraversal.TH where

import Language.Haskell.TH
import Data.StructuralTraversal.Class
import Data.Maybe
import Data.List (foldl', nub)
import Control.Monad
import Control.Applicative

-- | Derive a 'StructuralTraversable' instance for the data type or newtype
-- with the given name, e.g. @deriveStructTrav ''MyType@.
--
-- The instance is declared for the type constructor applied to all but its
-- last type parameter; the last parameter is the one traversed on and must
-- have kind @*@.
deriveStructTrav :: Name -> Q [Dec]
deriveStructTrav nm = reify nm >>= \case
    TyConI (DataD _ tyConName typArgs _ dataCons _) ->
      createInstance tyConName typArgs dataCons
    TyConI (NewtypeD _ tyConName typArgs _ dataCon _) ->
      createInstance tyConName typArgs [dataCon]
    TyConI _ -> fail "Unsupported data type"
    _ -> fail "Expected the name of a data type or newtype"

-- | Build the 'StructuralTraversable' instance declaration for a type
-- constructor with the given type parameters and data constructors.
--
-- Both 'traverseUp' and 'traverseDown' get one clause per constructor. The
-- instance context contains a 'StructuralTraversable' constraint for every
-- bare type variable that is applied to a traversed field (for example the
-- @t@ in a field of type @t a@, where @a@ is the traversed parameter).
--
-- Fails when the type has no type parameters, when its last parameter is not
-- of kind @*@, or when one of its constructors is existential or GADT-style.
createInstance :: Name -> [TyVarBndr] -> [Con] -> Q [Dec]
createInstance tyConName typArgs dataCons = do
    -- Report a clear error instead of crashing in "init" below.
    when (null typArgs) $
      fail $ "The kind of type " ++ show tyConName ++ " is *"
    (upClauses, preds) <- unzip <$> mapM (createClause upName) dataCons
    downClauses <- map fst <$> mapM (createClause downName) dataCons
    let instanceHead =
          AppT (ConT className)
               (foldl' AppT (ConT tyConName) (map getTypVarTyp (init typArgs)))
    -- The same constraint is produced once per field needing it; keep one.
    return [ InstanceD Nothing (nub (concat preds)) instanceHead
               [FunD upName upClauses, FunD downName downClauses] ]
  where
    -- The type variable the traversal operates on: the last type parameter,
    -- which must have kind *.
    varToTraverseOn :: Q Name
    varToTraverseOn = case reverse typArgs of
      PlainTV v : _ -> return v
      KindedTV v StarT : _ -> return v
      KindedTV _ _ : _ -> fail "The kind of the last type parameter is not *"
      [] -> fail $ "The kind of type " ++ show tyConName ++ " is *"

    -- Create the clause of funN for one data constructor, together with the
    -- instance context its fields require.
    createClause :: Name -> Con -> Q (Clause, [Pred])
    createClause funN (RecC conName conArgs) =
      createClause' funN conName (map (\(_, _, t) -> t) conArgs)
    createClause funN (NormalC conName conArgs) =
      createClause' funN conName (map snd conArgs)
    createClause funN (InfixC conArg1 conName conArg2) =
      createClause' funN conName [snd conArg1, snd conArg2]
    createClause _ _ =
      fail $ "Unsupported constructor in " ++ show tyConName
               ++ ": only normal, record and infix constructors are supported"

    -- Generate  funN desc asc f (Con p1 .. pn) = Con <$> e1 <*> .. <*> en
    -- where each ei handles the field pi according to its type.
    createClause' :: Name -> Name -> [Type] -> Q (Clause, [Pred])
    createClause' funN conName argTypes = do
      boundNames <- replicateM (length argTypes) (newName "p")
      (fieldExprs, ctx) <- unzip <$> zipWithM (processParam funN) boundNames argTypes
      let pats = [VarP desc, VarP asc, VarP f, ConP conName (map VarP boundNames)]
      return (Clause pats (NormalB (createExpr conName fieldExprs)) [], concat ctx)

    -- Combine the per-field expressions applicatively under the constructor;
    -- a constructor without fields becomes  pure Con.
    createExpr :: Name -> [Exp] -> Exp
    createExpr ctrName [] = AppE applPure (ConE ctrName)
    createExpr ctrName (param1 : params) =
      foldl' (\acc next -> InfixE (Just acc) applStar (Just next))
             (InfixE (Just (ConE ctrName)) applDollar (Just param1))
             params

    -- Quoted names always resolve to the definitions in base. A qualified
    -- mkName string is bound dynamically at the splice site instead, so it
    -- only works where Control.Applicative is in scope.
    applStar = VarE '(<*>)
    applDollar = VarE '(<$>)
    applPure = VarE 'pure

    className = ''StructuralTraversable
    upName = 'traverseUp
    downName = 'traverseDown

    -- Argument names of the generated clauses. The field variables are made
    -- with newName, so they cannot clash with these.
    desc = mkName "desc"
    asc = mkName "asc"
    f = mkName "f"

    -- Create the expression (and required context) for one constructor field
    -- bound to name:
    --   * the traversed type variable itself:  f name
    --   * an application whose innermost argument is the traversed variable,
    --     e.g. t (s a):  funN desc asc (funN desc asc f) name, requiring the
    --     class for every bare type variable applied along the way
    --   * anything else:  pure name  (the field is left untouched)
    processParam :: Name -> Name -> Type -> Q (Exp, [Pred])
    processParam _ name (VarT v) = do
      travV <- varToTraverseOn
      return $ if v == travV
                 then (AppE (VarE f) (VarE name), [])
                 else (AppE applPure (VarE name), [])
    processParam funN name (AppT tf ta) = do
      expr <- createExprForHighKind' funN name (VarE f) ta
      return $ case expr of
        Just (e, ctx) -> (e, addConstraintFor tf ctx)
        Nothing -> (AppE applPure (VarE name), [])
    processParam _ name _ = return (AppE applPure (VarE name), [])

    -- Build the traversal of a field whose type is an application chain
    -- ending in the traversed variable. inner is the traversal of the element
    -- type; every enclosing application wraps it in another  funN desc asc.
    -- Nothing when the innermost argument is not the traversed variable.
    createExprForHighKind' :: Name -> Name -> Exp -> Type -> Q (Maybe (Exp, [Pred]))
    createExprForHighKind' funN name inner (AppT tf ta) = do
      res <- createExprForHighKind' funN name (applExpr funN inner) ta
      return $ fmap (\(e, ctx) -> (e, addConstraintFor tf ctx)) res
    createExprForHighKind' funN name inner (VarT v) = do
      travV <- varToTraverseOn
      return $ if v == travV
                 then Just (applExpr funN inner `AppE` VarE name, [])
                 else Nothing
    createExprForHighKind' _ _ _ _ = return Nothing

    -- Require the class for t when t is a bare type variable.
    addConstraintFor :: Type -> [Pred] -> [Pred]
    addConstraintFor t ctx
      | isTypVar t = createConstraint className t : ctx
      | otherwise = ctx

    -- The expression  funN desc asc inner
    applExpr :: Name -> Exp -> Exp
    applExpr funN inner = VarE funN `AppE` VarE desc `AppE` VarE asc `AppE` inner

-- | The class constraint @name typ@.
-- Predicates are ordinary types from GHC 7.10; earlier versions used ClassP.
createConstraint :: Name -> Type -> Pred
createConstraint name typ
#if __GLASGOW_HASKELL__ >= 710
  = AppT (ConT name) typ
#else
  = ClassP name [typ]
#endif

-- | Whether the type is a bare type variable.
isTypVar :: Type -> Bool
isTypVar (VarT _) = True
isTypVar _ = False

-- | The type variable introduced by a binder, as a type.
getTypVarTyp :: TyVarBndr -> Type
getTypVarTyp (PlainTV n) = VarT n
getTypVarTyp (KindedTV n _) = VarT n

-- | Debugging aid: pretty-print the declarations produced by a splice at
-- compile time and return them unchanged.
thExamine :: Q [Dec] -> Q [Dec]
thExamine decl = do
  d <- decl
  runIO (putStrLn (pprint d))
  return d