module Data.StructuralTraversal.TH where
import Language.Haskell.TH
import Data.StructuralTraversal.Class
import Data.Maybe
import Control.Monad
import Control.Applicative
-- | Derives a 'StructuralTraversable' instance for the @data@ or @newtype@
-- declaration that the given name refers to.
--
-- The name is reified; a @data@ declaration contributes all of its
-- constructors, a @newtype@ its single constructor. Any other kind of
-- declaration, or a name that is not a type constructor, makes the splice
-- fail with an error message.
deriveStructTrav :: Name -> Q [Dec]
deriveStructTrav nm = do
  info <- reify nm
  case info of
    TyConI (DataD _ tyConName typArgs _ dataCons _) ->
      createInstance tyConName typArgs dataCons
    TyConI (NewtypeD _ tyConName typArgs _ dataCon _) ->
      createInstance tyConName typArgs [dataCon]
    TyConI _ ->
      fail "Unsupported data type"
    _ ->
      fail "Expected the name of a data type or newtype"
-- | Builds the 'StructuralTraversable' instance declaration for a type
-- constructor, traversing over its /last/ type parameter.
--
-- Arguments: the type constructor's name, its type variable binders, and its
-- data constructors. The result is a single instance declaration whose head
-- applies the type constructor to all but its last parameter and which
-- defines 'traverseUp' and 'traverseDown' with one clause per constructor.
--
-- For every constructor field:
--
-- * a field whose type is exactly the traversed variable gets the traversal
--   function @f@ applied to it;
-- * a field of the form @t ... a@ whose innermost last argument is the
--   traversed variable is traversed with the class method itself (nested
--   applications nest the method); whenever the applied head is a type
--   variable, a class constraint on it is added to the instance context;
-- * every other field is lifted unchanged with @pure@.
--
-- The splice fails with a message when the type has no type parameter, when a
-- field mentions a type variable while the last parameter has an explicit
-- kind other than @*@, or when a constructor is not a plain, record or infix
-- constructor.
createInstance :: Name -> [TyVarBndr] -> [Con] -> Q [Dec]
createInstance tyConName typArgs dataCons = do
    -- Reject parameterless types up front: there is nothing to traverse on,
    -- and dropping the last parameter of an empty list would otherwise crash
    -- with an uninformative exception instead of this message.
    instArgs <- case reverse typArgs of
      []            -> fail noParamError
      (_ : revInit) -> return (reverse revInit)
    (upClauses, preds) <- unzip <$> mapM (createClause upName) dataCons
    -- Both methods yield the same constraints, so the second set is dropped.
    (downClauses, _) <- unzip <$> mapM (createClause downName) dataCons
    return
      [ InstanceD
          Nothing
          (concat preds)
          (AppT (ConT className)
                (foldl AppT (ConT tyConName) (map getTypVarTyp instArgs)))
          [FunD upName upClauses, FunD downName downClauses]
      ]
  where
    noParamError :: String
    noParamError = "The kind of type " ++ show tyConName ++ " is *"

    -- The type variable the generated instance traverses on: the last type
    -- parameter, which must be unkinded or of kind *.
    varToTraverseOn :: Q Name
    varToTraverseOn = case reverse typArgs of
      (PlainTV lastVar : _)        -> return lastVar
      (KindedTV lastVar StarT : _) -> return lastVar
      (KindedTV _ _ : _)           -> fail $ "The kind of the last type parameter is not *"
      []                           -> fail noParamError

    -- Creates the clause of the method 'funN' for one data constructor,
    -- together with the constraints its fields require.
    createClause :: Name -> Con -> Q (Clause, [Pred])
    createClause funN (RecC conName conArgs)
      = createClause' funN conName (map (\(_, _, r) -> r) conArgs)
    createClause funN (NormalC conName conArgs)
      = createClause' funN conName (map snd conArgs)
    createClause funN (InfixC conArg1 conName conArg2)
      = createClause' funN conName [snd conArg1, snd conArg2]
    -- Remaining constructor forms are reported instead of aborting the
    -- splice with a pattern-match failure.
    createClause _ con
      = fail $ "Unsupported constructor (existential or GADT-style): " ++ pprint con

    -- Binds one fresh variable per field and rebuilds the constructor
    -- applicatively from the processed fields.
    createClause' :: Name -> Name -> [Type] -> Q (Clause, [Pred])
    createClause' funN conName argTypes = do
      bindedNames <- replicateM (length argTypes) (newName "p")
      (handleParams, ctx) <- unzip <$> zipWithM (processParam funN)
                                                bindedNames argTypes
      return ( Clause [ VarP desc, VarP asc, VarP f
                      , ConP conName (map VarP bindedNames) ]
                      (NormalB (createExpr conName handleParams)) []
             , concat ctx )

    -- @pure C@ for a nullary constructor, otherwise @C \<$\> e1 \<*\> e2 ...@.
    createExpr :: Name -> [Exp] -> Exp
    createExpr ctrName []
      = AppE applPure $ ConE ctrName
    createExpr ctrName (param1 : params)
      = foldl (\coll new -> InfixE (Just coll) applStar (Just new))
              (InfixE (Just $ ConE ctrName) applDollar (Just param1))
              params

    -- Quoted names refer to the original definitions directly, so the
    -- generated code does not rely on how (or whether) the module containing
    -- the splice imports Control.Applicative.
    applStar, applDollar, applPure :: Exp
    applStar   = VarE '(<*>)
    applDollar = VarE '(<$>)
    applPure   = VarE 'pure

    className, upName, downName :: Name
    className = ''StructuralTraversable
    upName    = 'traverseUp
    downName  = 'traverseDown

    -- Names of the three leading arguments of every generated clause.
    desc, asc, f :: Name
    desc = mkName "desc"
    asc  = mkName "asc"
    f    = mkName "f"

    -- Lifts a field that is not traversed: @pure p@.
    keepField :: Name -> Exp
    keepField name = AppE applPure (VarE name)

    -- Adds a class constraint on 'tf' when it is a type variable.
    constrainHead :: Type -> [Pred] -> [Pred]
    constrainHead tf ctx
      | isTypVar tf = createConstraint className tf : ctx
      | otherwise   = ctx

    -- Produces the expression handling one bound field of the given type,
    -- plus the constraints that expression needs.
    processParam :: Name -> Name -> Type -> Q (Exp, [Pred])
    processParam _ name (VarT v) = do
      travV <- varToTraverseOn
      return ( if v == travV then AppE (VarE f) (VarE name)
                             else keepField name
             , [] )
    processParam funN name (AppT tf ta) = do
      res <- createExprForHighKind' funN name (VarE f) ta
      return $ case res of
        Just (e, ctx) -> (e, constrainHead tf ctx)
        Nothing       -> (keepField name, [])
    processParam _ name _
      = return (keepField name, [])

    -- Follows the last type argument of nested type applications. Each
    -- application level wraps the traversal function in one more call of the
    -- method 'funN'; the result is Nothing unless the innermost argument is
    -- the traversed variable.
    createExprForHighKind' :: Name -> Name -> Exp -> Type -> Q (Maybe (Exp, [Pred]))
    createExprForHighKind' funN name fun (AppT tf ta) = do
      res <- createExprForHighKind' funN name (applExpr funN fun) ta
      return $ fmap (\(e, ctx) -> (e, constrainHead tf ctx)) res
    createExprForHighKind' funN name fun (VarT v) = do
      travV <- varToTraverseOn
      return $ if v == travV
                 then Just (applExpr funN fun `AppE` VarE name, [])
                 else Nothing
    createExprForHighKind' _ _ _ _
      = return Nothing

    -- @funN desc asc fun@
    applExpr :: Name -> Exp -> Exp
    applExpr funN fun = VarE funN `AppE` VarE desc `AppE` VarE asc `AppE` fun
-- | Builds the single-parameter class constraint @cls ty@.
--
-- The representation of predicates differs between compiler versions, hence
-- the conditional compilation: from GHC 7.10 on a predicate is an ordinary
-- type application, before that it is built with the @ClassP@ constructor.
createConstraint :: Name -> Type -> Pred
createConstraint cls ty
#if __GLASGOW_HASKELL__ >= 710
  = ConT cls `AppT` ty
#else
  = ClassP cls [ty]
#endif
-- | Tells whether the given type is a bare type variable.
isTypVar :: Type -> Bool
isTypVar typ = case typ of
  VarT _ -> True
  _      -> False
-- | Turns a type variable binder into the type referring to that variable,
-- discarding any kind annotation.
getTypVarTyp :: TyVarBndr -> Type
getTypVarTyp bndr = VarT boundName
  where
    boundName = case bndr of
      PlainTV n    -> n
      KindedTV n _ -> n
-- | Debugging aid: runs the given declaration generator, writes the
-- pretty-printed declarations to standard output and returns them unchanged.
thExamine :: Q [Dec] -> Q [Dec]
thExamine decl = do
  decs <- decl
  runIO . putStrLn $ pprint decs
  return decs