-- | Template Haskell generation of a two-parameter GADT (plus an empty
-- label type) from an ordinary parameterless data type.
module Data.Comp.Trans.DeriveMulti
  ( deriveMulti
  ) where

import Control.Lens ( traverse, _1, _2, _3, (&), (%~), (%%~) )
import Control.Monad ( liftM )
import Data.Functor ( (<$>) )
import Language.Haskell.TH.Syntax
import Language.Haskell.TH.ExpandSyns ( expandSyns )

import Data.Comp.Trans.Names ( baseTypes, transName, nameLab, getLab )

-- | Reify the given name and, when it denotes a @data@ or @newtype@
-- declaration without type parameters, generate the declarations produced
-- by 'mkGADT' for it.
--
-- For anything else an error is reported through 'reportError' and an
-- empty declaration list is returned.
deriveMulti :: Name -> Q [Dec]
deriveMulti n = reify n >>= derive
  where
    derive :: Info -> Q [Dec]
    derive (TyConI (DataD _ tyName [] ctors _))   = mkGADT tyName ctors
    -- A newtype has exactly one constructor; treat it as a one-element list.
    derive (TyConI (NewtypeD _ tyName [] ctor _)) = mkGADT tyName [ctor]
    derive _ = do
      reportError $
        "Attempted to derive multi-sorted compositional data type for "
          ++ show n
          ++ ", which is not a nullary datatype"
      return []

-- | Build the two generated declarations for a type named @n@ with
-- constructors @cons@:
--
--   1. a data type named @'transName' n@ with two type variables, the first
--      of kind @* -> *@ and the second unannotated, whose constructors are
--      the results of 'mkCon';
--
--   2. a data type named @'nameLab' n@ with no constructors.
mkGADT :: Name -> [Con] -> Q [Dec]
mkGADT n cons = do
  eVar <- newName "e"
  iVar <- newName "i"
  let gadtName = transName n
      binders  = [ KindedTV eVar (AppT (AppT ArrowT StarT) StarT)
                 , PlainTV iVar
                 ]
  gadtCons <- mapM (mkCon gadtName eVar iVar) cons
  return
    [ DataD [] gadtName binders gadtCons []
    , DataD [] (nameLab n) [] [] []
    ]

-- | Translate a single constructor.
--
-- The constructor name (and, for records, each field name) is passed
-- through 'transName', every field type is rewritten by 'unfixType', and the
-- result is wrapped in a 'ForallC' carrying the equality constraint
-- @i ~ 'nameLab' l@.
--
-- Only normal and record constructors are handled; any other constructor
-- form makes the computation 'fail'.
--
-- NOTE(review): @l@ is the name 'mkGADT' has already run through
-- 'transName', while the label type 'mkGADT' declares is 'nameLab' of the
-- untranslated name.  This presumably relies on both yielding the same
-- name -- confirm against "Data.Comp.Trans.Names".
mkCon :: Name -> Name -> Name -> Con -> Q Con
mkCon l e i con = case con of
  NormalC n sts ->
    constrain <$> liftM (NormalC (transName n)) (rewriteFields sts)
  RecC n vsts ->
    constrain <$> liftM (RecC (transName n)) (rewriteRecFields vsts)
  _ ->
    fail $ "Attempted to derive multi-sorted compositional datatype for something with non-normal constructors: " ++ show con
  where
    -- Attach the index equality to a finished constructor.
    constrain = ForallC [] [EqualP (VarT i) (ConT $ nameLab l)]

    -- Positional fields: rewrite the type component of each (strictness, type).
    rewriteFields fields = fields & (traverse . _2) %%~ unfixType e

    -- Record fields: rename the selector, then rewrite the type component.
    rewriteRecFields fields =
      fields & (traverse . _1) %~ transName
             & (traverse . _3) %%~ unfixType e

-- | Rewrite a field type for use inside the generated GADT.
--
-- Types listed in 'baseTypes' are returned untouched.  Any other type has
-- its synonyms expanded with 'expandSyns', is passed to 'getLab', and the
-- result is applied to the type variable @e@.
unfixType :: Name -> Type -> Q Type
unfixType e t
  | t `elem` baseTypes = return t
  | otherwise          = AppT (VarT e) <$> (expandSyns t >>= getLab)