{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE MonomorphismRestriction #-} {-# LANGUAGE StandaloneDeriving #-} {-# OPTIONS -fwarn-missing-signatures #-} {-| Example usage: @ import Generics.MultiRec import Generics.MultiRec.TH.Alt import Data.Tree data TheFam :: (* -> *) where Tree_Int :: TheFam (Tree Int) Forest_Int :: TheFam (Forest Int) $('deriveEverything' ('DerivOptions' [ ( [t| Tree Int |], \"Tree_Int\" ) , ( [t| Forest Int |], \"Forest_Int\" ) ] \"TheFam\" (\\t c -> \"CONSTRUCTOR_\" ++ t ++ \"_\" ++ c) \"ThePF\" True ) ) type instance 'PF' TheFam = ThePF @ -} module Generics.MultiRec.TH.Alt ( DerivOptions(..), deriveEverything, ) where import Generics.MultiRec.TH.Alt.DerivOptions(DerivOptions(..)) import THUtils(AppliedTyCon, (@@), (@@@), toAppliedTyCon, fromAppliedTyCon, atc2constructors, pprintUnqual, sMatch, sClause, cleanConstructorName) import BalancedFold(balancedFold, ascendFromLeaf) import MonadRQ(RQ, message, messageReport, liftq, foreachType, foreachTypeNumbered, runRQ) import Generics.MultiRec.Base((:>:)(..), C(..), El(..), (:=:)(..), I0(I0), (:*:)(..), (:+:)(..), EqS(..), Fam(..), I(I), K(K), U(..)) import Generics.MultiRec.Constructor(Associativity(..), Fixity(..), Constructor(..)) import Control.Monad.Reader(Monad(return, fail, (>>)), Functor(..), (=<<), mapM, sequence, liftM, zipWithM, asks) import Language.Haskell.TH.Syntax(Lift(..)) import Language.Haskell.TH(newName, mkName, wildP, clause, conE, appE, normalB, funD, dataD, instanceD, cxt, conT, appT, Exp(VarE, SigE, LamE, ConE, CaseE, AppE), Match, Clause, Q, Pat(WildP, VarP, ConP), TypeQ, Type(ConT), Dec(TySynD, InstanceD, FunD), Name, Con(RecC, NormalC, InfixC), FixityDirection(..), Info(DataConI), nameBase, reify, stringE) import Data.Map(lookup, elems) import Control.Applicative((<$>)) import qualified Data.Map as Map import qualified Language.Haskell.TH as TH 
-- | When 'True', the sum of alternatives (and the matching 'L'/'R'
-- injections in 'lrP'/'lrE') is built as a balanced tree via 'balancedFold'
-- and 'ascendFromLeaf'; when 'False' it is right-nested. Products are always
-- right-nested ('prodT').
bALANCED_MODE :: Bool
bALANCED_MODE = False

-- | Entry point: generate, for the family described by the options, the
-- per-constructor datatypes with their 'Constructor' instances
-- ('deriveConstructors'), followed by the pattern functor synonym and the
-- 'El', 'Fam' and 'EqS' instances ('deriveFamily').
--
-- NOTE: generation of sanity checks is disabled (its wiring and the
-- 'makeSanityChecks' helper are commented out).
deriveEverything :: DerivOptions [(TypeQ, String)] -> Q [Dec]
deriveEverything opts =
  runRQ (concat <$> sequence [deriveConstructors, deriveFamily]) opts

-- | For every type in the family, generate one empty datatype per value
-- constructor together with its 'Constructor' instance.
deriveConstructors :: RQ [Dec]
deriveConstructors = concat <$> foreachType constrInstance

-- | Generate the pattern functor synonym ('derivePF') and the 'El', 'Fam'
-- and 'EqS' instances for the index GADT named by 'indexGadtName', in that
-- order.
--
-- /IMPORTANT/: the generated code refers to the index GADT's constructors by
-- the strings paired with each family type in the options, so those strings
-- must be exactly the names of the GADT's constructors.
deriveFamily :: RQ [Dec]
deriveFamily = concat <$> sequence [derivePF, deriveEl, deriveFam, deriveEqS]

-- | Derive only the pattern functor: a type synonym named
-- 'patternFunctorName' that sums one ':>:'-tagged branch per family type.
-- Also reports a reminder that the @type instance PF@ line has to be added
-- by hand. Not needed if 'deriveFamily' is used.
derivePF :: RQ [Dec]
derivePF = do
  branches <- foreachType pfType
  pfn      <- asks patternFunctorName
  famName  <- asks indexGadtName
  messageReport
    (  "Reminder: Don't forget to add this line manually:\n"
    ++ " type instance PF "++famName++" = "++pfn
    )
  return [TySynD (mkName pfn) [] (sumT branches)]

-- | Combine types with ':+:', right-nested or balanced according to
-- 'bALANCED_MODE'.
sumT :: [Type] -> Type
sumT
  | bALANCED_MODE = balancedSumT
  | otherwise     = rightSumT

-- | @a :+: (b :+: (... :+: z))@. Partial: the list must be non-empty
-- ('foldr1').
rightSumT :: [Type] -> Type
rightSumT = foldr1 plusT

-- | Balanced ':+:' tree, as produced by 'balancedFold'.
balancedSumT :: [Type] -> Type
balancedSumT = balancedFold plusT

-- | @a :+: b@
plusT :: Type -> Type -> Type
plusT a b = ConT ''(:+:) @@ a @@ b

-- | @a :*: (b :*: (... :*: z))@. Partial: the list must be non-empty
-- ('foldr1'); callers use 'U' for nullary constructors instead.
prodT :: [Type] -> Type
prodT = foldr1 timesT

-- | @a :*: b@
timesT :: Type -> Type -> Type
timesT a b = ConT ''(:*:) @@ a @@ b

-- | Derive only the 'El' instances (one per family type). Not needed if
-- 'deriveFamily' is used.
deriveEl :: RQ [Dec]
deriveEl = foreachType elInstance

-- | The index GADT as a type constructor, named by 'indexGadtName'.
indexGadtType :: RQ Type
indexGadtType = ConT . mkName <$> asks indexGadtName

-- | Derive only the 'Fam' instance ('from' and 'to'). Not needed if
-- 'deriveFamily' is used.
deriveFam :: RQ [Dec]
deriveFam = do
  fcs <- concat <$> foreachTypeNumbered mkFrom
  tcs <- foreachTypeNumbered mkTo
  s   <- indexGadtType
  return [ InstanceD [] (ConT ''Fam @@ s) [FunD 'from fcs, FunD 'to tcs] ]

-- | Derive only the 'EqS' instance. Not needed if 'deriveFamily' is used.
--
-- One clause @eqS C C = Just Refl@ per index constructor, plus a catch-all
-- @eqS _ _ = Nothing@ unless there is exactly one index constructor (in
-- which case its single clause already covers every case).
deriveEqS :: RQ [Dec]
deriveEqS = do
  s  <- indexGadtType
  ns <- elems <$> asks familyTypes
  return [ InstanceD [] (ConT ''EqS @@ s) [FunD 'eqS (trues ns ++ falses ns)] ]
  where
    trueClause n =
      sClause [ConP (mkName n) [], ConP (mkName n) []]
              (ConE 'Just `AppE` ConE 'Refl)
    falseClause = sClause [WildP, WildP] (ConE 'Nothing)
    trues names = fmap trueClause names
    falses [_]  = []
    falses _    = [falseClause]

-- | For one family type (paired with its index-constructor string), generate
-- the per-constructor empty datatypes followed by their 'Constructor'
-- instances.
constrInstance :: (AppliedTyCon,String) -> RQ [Dec]
constrInstance (atc,s) = do
  cs <- liftq (atc2constructors atc)
  ds <- mapM (mkData s) cs
  is <- mapM (mkInstance s) cs
  return $ ds ++ is

-- | Turn a record constructor into the equivalent positional one; every
-- other constructor form is returned unchanged.
stripRecordNames :: Con -> Con
stripRecordNames (RecC n f) = NormalC n (fmap (\(_, s, t) -> (s, t)) f)
stripRecordNames c          = c

-- | Name of the generated datatype representing constructor @n@ of the family
-- type whose index-constructor string is @s@: the user's
-- 'constructorNameModifier' applied to @s@ and the cleaned constructor name.
--
-- TODO (carried over): handle colons in constructor names; check whether
-- 'cleanConstructorName' already covers this.
conTypeName :: String -> Name -> RQ Name
conTypeName s n = do
  modifier <- asks constructorNameModifier
  return (mkName . modifier s . cleanConstructorName . nameBase $ n)

-- | Abort the splice with a readable message for constructor forms this
-- module does not handle (e.g. existentially quantified constructors),
-- instead of an opaque non-exhaustive-pattern failure.
unsupportedCon :: String -> Con -> RQ a
unsupportedCon fn c =
  fail ("Generics.MultiRec.TH.Alt." ++ fn
        ++ ": unsupported constructor form: " ++ TH.pprint c)

-- | Empty datatype (no constructors) standing for one value constructor.
mkData :: String -> Con -> RQ Dec
mkData s (NormalC n _) = do
  n' <- conTypeName s n
  liftq $ dataD (cxt []) n' [] [] []
mkData s r@(RecC _ _)     = mkData s (stripRecordNames r)
mkData s (InfixC t1 n t2) = mkData s (NormalC n [t1,t2])
mkData _ c                = unsupportedCon "mkData" c

-- Orphan 'Lift' instances so that a 'Fixity' computed at splice time can be
-- spliced back into the generated code with @[| fi |]@ (see 'mkInstance').
instance Lift Fixity where
  lift Prefix      = conE 'Prefix
  lift (Infix a n) = conE 'Infix `appE` [| a |] `appE` [| n |]

instance Lift Associativity where
  lift LeftAssociative  = conE 'LeftAssociative
  lift RightAssociative = conE 'RightAssociative
  lift NotAssociative   = conE 'NotAssociative

-- | 'Constructor' instance for the datatype generated by 'mkData'.
-- 'conName' returns the original constructor's base name; for infix
-- constructors 'conFixity' is also defined, taken from 'reify' (falling back
-- to 'Prefix' if 'reify' does not return a 'DataConI').
mkInstance :: String -> Con -> RQ Dec
mkInstance s (NormalC n _) = do
  n' <- conTypeName s n
  liftq $ instanceD (cxt []) (appT (conT ''Constructor) (conT n'))
    [funD 'conName [clause [wildP] (normalB (stringE (nameBase n))) []]]
mkInstance s r@(RecC _ _) = mkInstance s (stripRecordNames r)
mkInstance s (InfixC _ n _) = do
  n'   <- conTypeName s n
  info <- liftq (reify n)
  let fi = case info of
             DataConI _ _ _ f -> convertFixity f
             _                -> Prefix
  liftq $ instanceD (cxt []) (appT (conT ''Constructor) (conT n'))
    [ funD 'conName   [clause [wildP] (normalB (stringE (nameBase n))) []]
    , funD 'conFixity [clause [wildP] (normalB [| fi |]) []]
    ]
  where
    convertFixity (TH.Fixity prec dir) = Infix (convertDirection dir) prec
    convertDirection InfixL = LeftAssociative
    convertDirection InfixR = RightAssociative
    convertDirection InfixN = NotAssociative
mkInstance _ c = unsupportedCon "mkInstance" c

-- | Pattern-functor branch for one family type:
-- @(C Con1 fields1 :+: ... :+: C ConN fieldsN) :>: Ty@.
-- Fails for datatypes without constructors ('guardEmptyData').
pfType :: (AppliedTyCon,String) -> RQ Type
pfType (atc,s) = do
  cs <- liftq (atc2constructors atc)
  guardEmptyData cs atc
  b  <- sumT <$> mapM (pfCon s) cs
  return $ ConT ''(:>:) @@ b @@ fromAppliedTyCon atc

-- | Pattern-functor alternative for one constructor: @C ConTy fields@, where
-- @fields@ is 'U' for nullary constructors and a ':*:' product of
-- 'pfField's otherwise.
pfCon :: String -> Con -> RQ Type
pfCon s (NormalC n fs) = do
  n' <- conTypeName s n
  fieldResults <- mapM (pfField . snd) fs
  let rest = case fs of
               [] -> ConT ''U
               _  -> prodT fieldResults
  return $ ConT ''C @@ ConT n' @@ rest
pfCon s r@(RecC _ _)     = pfCon s (stripRecordNames r)
pfCon s (InfixC t1 n t2) = pfCon s (NormalC n [t1,t2])
pfCon _ c                = unsupportedCon "pfCon" c

-- | @I t@ if @t@ is a family member (a recursive position), @K t@ otherwise.
pfField :: Type -> RQ Type
pfField t = ifInFamily t (ConT ''I @@ t) (ConT ''K @@ t)

-- | Look @t@ up in the options' 'familyTypes' map (after converting it with
-- 'toAppliedTyCon'). Returns the associated index-constructor string, or
-- 'Nothing' if the type is not a family member or cannot be converted.
lookupFam :: Type -> RQ (Maybe String)
lookupFam t = do
  ts <- asks familyTypes
  t' <- liftq $ toAppliedTyCon t
  return $ case t' of
    Right t'' -> Map.lookup t'' ts
    Left _    -> Nothing

-- | Pure-valued variant of 'ifInFamily''.
ifInFamily :: Type -> a -> a -> RQ a
ifInFamily n x y = ifInFamily' n (return x) (return y)

-- | Run the first action if the type is a family member, the second
-- otherwise.
ifInFamily' :: Type -> RQ a -> RQ a -> RQ a
ifInFamily' t x y = maybe y (const x) =<< lookupFam t

-- | @instance El Fam Ty where proof = IndexCon@ for one family type.
elInstance :: (AppliedTyCon,String) -> RQ Dec
elInstance x@(atc,_) = do
  s   <- indexGadtType
  prf <- mkProof x
  return $ InstanceD [] (ConT ''El @@ s @@ fromAppliedTyCon atc) [prf]

-- | The 'from' clauses for one family type, one per value constructor.
-- @m@ and @i@ select this type's branch in the pattern-functor sum
-- (presumably the family size and this type's index, as supplied by
-- 'foreachTypeNumbered').
mkFrom :: Int -> Int -> (AppliedTyCon,String) -> RQ [Clause]
mkFrom m i (atc,s) = do
  cs <- liftq (atc2constructors atc)
  let wrapE e = lrE m i (ConE 'Tag @@@ e)
      dn      = mkName s
  zipWithM (fromCon wrapE dn (length cs)) [0..] cs

-- | The 'to' clause for one family type:
--
-- > to IndexCon = \x -> case (x :: PF I0 Ty) of
-- >   <L/R path to this type> con -> case (con :: <branch> I0 Ty) of ...
mkTo :: Int -> Int -> (AppliedTyCon,String) -> RQ Clause
mkTo m i (atc,s) = do
  cs     <- liftq (atc2constructors atc)
  pfname <- mkName <$> asks patternFunctorName
  matchesOfCons <- zipWithM (toCon (length cs) atc) [0..] cs
  xvar   <- liftq (newName "x")
  convar <- liftq (newName "con")
  typeOfConvar <- do
    t0 <- pfType (atc,s)
    return (t0 @@ ConT ''I0 @@ fromAppliedTyCon atc)
  let typeOfXvar = ConT pfname @@ ConT ''I0 @@ fromAppliedTyCon atc
      innerCase  = CaseE (VarE convar `SigE` typeOfConvar) matchesOfCons
      body       = LamE [VarP xvar]
                     (CaseE (VarE xvar `SigE` typeOfXvar)
                        [sMatch (lrP m i (VarP convar)) innerCase])
  return (sClause [ConP (mkName s) []] body)

-- | The 'proof' method of an 'El' instance: the index constructor named @s@.
mkProof :: (AppliedTyCon,String) -> RQ Dec
mkProof (_,s) = return $ FunD 'proof [sClause [] (ConE (mkName s))]

-- | One 'from' clause: @from IndexCon (Con f0 .. fk) = wrap (<L/R path> (C fields))@.
-- @m@ is the number of constructors and @i@ this constructor's index.
fromCon :: (Exp -> Exp) -> Name -> Int -> Int -> Con -> RQ Clause
fromCon wrap n m i (NormalC cn []) =
  -- Nullary constructor: encoded as 'C' applied to 'U'.
  return $ sClause [ConP n [], ConP cn []] (wrap . lrE m i $ ConE 'C @@@ ConE 'U)
fromCon wrap n m i (NormalC cn fs) = do
  rhs <- zipWithM fromField [0..] (snd <$> fs)
  return $ sClause [ ConP n [], ConP cn (fmap (VarP . field) [0..length fs - 1]) ]
                   (wrap . lrE m i $ ConE 'C @@@ foldr1 prod rhs)
  where
    prod x y = ConE '(:*:) @@@ x @@@ y
fromCon wrap n m i r@(RecC _ _)      = fromCon wrap n m i (stripRecordNames r)
fromCon wrap n m i (InfixC t1 cn t2) = fromCon wrap n m i (NormalC cn [t1,t2])
fromCon _ _ _ _ c                    = unsupportedCon "fromCon" c

-- | One match of the inner 'to' case: @Tag (<L/R path> (C fields)) -> Con f0 .. fk@.
toCon :: Int          -- ^ Number of constructors
      -> AppliedTyCon
      -> Int          -- ^ Index of this constructor
      -> Con
      -> RQ Match
toCon m _ i (NormalC cn []) =
  -- Nullary constructor: matches 'C' applied to 'U'.
  return $ sMatch (ConP 'Tag [lrP m i $ ConP 'C [ConP 'U []]]) (ConE cn)
toCon m _ i (NormalC cn fs) = do
  lhs <- zipWithM toField [0..] (fmap snd fs)
  return $ sMatch (ConP 'Tag [lrP m i $ ConP 'C [foldr1 prod lhs]])
                  (foldl AppE (ConE cn) (fmap (VarE . field) [0..length fs - 1]))
  where
    prod x y = ConP '(:*:) [x,y]
toCon m atc i r@(RecC _ _)      = toCon m atc i (stripRecordNames r)
toCon m atc i (InfixC t1 cn t2) = toCon m atc i (NormalC cn [t1,t2])
toCon _ _ _ c                   = unsupportedCon "toCon" c

-- | Encode field number @nr@ of type @t@ for 'from': @I (I0 fN)@ for family
-- members, @K fN@ otherwise (also emitting an info message).
fromField :: Int -> Type -> RQ Exp
fromField nr t =
  ifInFamily' t
    (return (ConE 'I @@@ (ConE 'I0 @@@ VarE (field nr))))
    (message ("* Info: Type not in family: " ++ pprintUnqual t) >>
     return (ConE 'K @@@ VarE (field nr)))

-- | Pattern inverse of 'fromField': @I (I0 fN)@ or @K fN@.
toField :: Int -> Type -> RQ Pat
toField nr t = ifInFamily t (ConP 'I [ConP 'I0 [VarP (field nr)]])
                            (ConP 'K [VarP (field nr)])

-- | Variable @fN@ bound to the N-th constructor field.
field :: Int -> Name
field n = mkName $ "f" ++ show n

-- | Wrap a pattern in the 'L'/'R' injections selecting alternative @i@
-- (0-based) out of @m@. In right-nested mode this mirrors 'rightSumT'; in
-- balanced mode it delegates to 'ascendFromLeaf'. Expects @0 <= i < m@.
lrP :: Int -> Int -> (Pat -> Pat)
lrP m i p
  | bALANCED_MODE = ascendFromLeaf (ConP 'L . (:[])) (ConP 'R . (:[])) p m i
lrP 1 0 p = p
lrP _ 0 p = ConP 'L [p]
lrP m i p = ConP 'R [lrP (m-1) (i-1) p]

-- | Expression counterpart of 'lrP'.
lrE :: Int -> Int -> (Exp -> Exp)
lrE m i e
  | bALANCED_MODE = ascendFromLeaf (ConE 'L @@@) (ConE 'R @@@) e m i
lrE 1 0 e = e
lrE _ 0 e = ConE 'L @@@ e
lrE m i e = ConE 'R @@@ lrE (m-1) (i-1) e

-- | Abort the splice for datatypes without constructors; their pattern
-- functor would be an empty sum, which 'sumT' cannot build.
guardEmptyData :: [Con] -> AppliedTyCon -> RQ ()
guardEmptyData [] atc =
  fail ("Empty types not supported yet ("
        ++ pprintUnqual (fromAppliedTyCon atc) ++ ")")
guardEmptyData _ _ = return ()

-- Disabled debugging helper, kept for reference:
--
-- helper t = do
--   Right (AppliedTyCon n args) <- liftq (toAppliedTyCon t)
--   let prefix = "Prf_"
--   str <- if n == ''[]
--     then do
--       Right (AppliedTyCon n1 _) <- liftq (toAppliedTyCon (head args))
--       return ("T("++prefix++"List"++nameBase n1
--               ++",["++pprintUnqual (head args)++"])")
--     else
--       return ("T("++prefix++nameBase n
--               ++","++pprintUnqual t++")")
--   liftq . runIO $ appendFile "dump.dump" (str++"\n")

-- | Ignores the type and returns the expression unchanged. Currently unused;
-- NOTE(review): presumably a drop-in replacement for 'SigE' when type
-- annotations on generated expressions are not wanted.
noSigE :: Exp -> Type -> Exp
x `noSigE` _ = x

-- Disabled sanity-check generation, kept for reference:
--
-- makeSanityChecks :: RQ [Dec]
-- makeSanityChecks = concat <$> foreachType makeSanityCheck
--
-- makeSanityCheck :: (AppliedTyCon,String) -> RQ [Dec]
-- makeSanityCheck (atc,s) = do
--   famname <- mkName <$> asks indexGadtName
--   let chkName = mkName ("sanityCheck"++s)
--   return [ SigD chkName (ConT famname @@ fromAppliedTyCon atc)
--          , ValD (VarP chkName) (NormalB (ConE (mkName s))) []
--          ]