{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE MonomorphismRestriction #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS -fwarn-missing-signatures #-}

{-|
Example usage:

@
import Generics.MultiRec
import Generics.MultiRec.TH.Alt
import Data.Tree

data TheFam :: (* -> *) where
    Tree_Int   :: TheFam (Tree Int)
    Forest_Int :: TheFam (Forest Int)

$('deriveEverything'
   ('DerivOptions'
      [ ( [t| Tree Int |], \"Tree_Int\" )
      , ( [t| Forest Int |], \"Forest_Int\" )
      ]
      \"TheFam\"
      (\\t c -> \"CONSTRUCTOR_\" ++ t ++ \"_\" ++ c)
      \"ThePF\"
      True
   )
 )

type instance 'PF' TheFam = ThePF
@
-}
module Generics.MultiRec.TH.Alt
  ( DerivOptions(..)
  , deriveEverything
    -- RQ, runRQ,
    -- deriveConstructors,
    -- deriveFamily,
    -- derivePF,
    -- deriveEl,
    -- deriveFam,
    -- deriveEqS
  ) where

import Generics.MultiRec.Base
import Generics.MultiRec.Constructor
import Language.Haskell.TH hiding (Fixity())
import Language.Haskell.TH.Syntax (Lift(..))
import Language.Haskell.TH.ExpandSyns
import Language.Haskell.TH.Ppr
import Control.Monad
import Control.Monad.Reader hiding (lift)
import qualified Control.Monad.Reader as Reader
import Control.Applicative
import Control.Arrow
import THUtils
import Data.Map as Map
import Data.Set as Set hiding(elems)
import qualified Data.Foldable as Fold
import BalancedFold

-- | Selects how sums are built. When 'True', sums are built with
-- 'balancedFold' and the matching injections with 'ascendFromLeaf'.
-- When 'False', sums are right-nested ('foldr1') and injections are
-- @R (R ... (L x))@ chains.
bALANCED_MODE :: Bool
bALANCED_MODE = True

-- | Options controlling the generated code. The type parameter is the
-- representation of 'familyTypes'.
--
-- Users supply @[(TypeQ, String)]@. Internally 'runRQ' replaces it by a
-- @Map AppliedTyCon String@.
data DerivOptions ft = DerivOptions
    { -- | A list of:
      --
      -- > (type quotation, name of corresponding constructor of the family GADT)
      --
      -- This defines our mutually recursive family. The types must resolve to
      -- @data@types or @newtype@s of kind @*@ (type synonyms will be expanded).
      familyTypes :: ft

      -- | Name of the family GADT. This type has to be written by hand
      -- because TH doesn't support GADTs yet.
    , indexGadtName :: String

      -- | Scheme for producing names for the empty types corresponding to
      -- constructors.
      --
      -- The first argument is the family-GADT constructor name given for the
      -- type in 'familyTypes'. The second argument is the name of the data
      -- constructor; builtins are called @NIL@, @CONS@, @TUPLE2@, @TUPLE3@, ...
    , constructorNameModifier :: String -> String -> String

      -- | Name of the pattern functor type synonym ('PF') to generate.
    , patternFunctorName :: String

      -- | Print various informational messages at compile time?
    , verbose :: Bool
    -- , mkSanityChecks :: Bool
    }

-- | Maps over the 'familyTypes' field only; all other fields are kept.
instance Functor DerivOptions where
    fmap f d = d { familyTypes = f (familyTypes d) }

-- | Turn the base name of a data constructor into something usable inside a
-- type name.
--
-- * Tuple constructors @(,)@, @(,,)@, ... become @TUPLE2@, @TUPLE3@, ...
--   (the number is the length of the name minus one).
-- * The list constructors @[]@ and @:@ become @NIL@ and @CONS@.
-- * Every other name (including the empty string) is returned unchanged.
cleanConstructorName :: [Char] -> [Char]
cleanConstructorName c
    | isParenthesised = "TUPLE" ++ show (length c - 1)
    | c == "[]"       = "NIL"
    | c == ":"        = "CONS"
    | otherwise       = c
  where
    -- Total replacement for @head c == '(' && last c == ')'@.
    -- A lone "(" is not treated as parenthesised.
    isParenthesised = case c of
        '(' : rest@(_:_) -> last rest == ')'
        _                -> False

-- | Print a line to stdout (at compile time, via 'runIO') when 'verbose' is
-- set.
message :: String -> RQ ()
message x = do
    b <- asks verbose
    when b (liftq . runIO . putStrLn $ x ++ "\n")

-- | Emit a compiler report with @report False@ (a non-fatal report) when
-- 'verbose' is set.
messageReport :: String -> RQ ()
messageReport x = do
    b <- asks verbose
    when b (liftq . report False $ x ++ "\n")

-- checkOptions :: DerivOptions -> Q ()
-- checkOptions (DerivOptions{..}) =
--     do
--       when (null familyTypes) (fail "empty family")

-- | The derivation monad: 'Q' with read access to the options. In the
-- options, the family types are resolved to a map from 'AppliedTyCon' to the
-- family-GADT constructor name.
type RQ = ReaderT (DerivOptions (Map AppliedTyCon String)) Q

-- | Lift a 'Q' computation into 'RQ'.
liftq :: Q a -> RQ a
liftq = Reader.lift

-- | Run an action for every (type, GADT constructor name) pair of the
-- family, in 'Map.toList' order.
foreachType :: ((AppliedTyCon,String) -> RQ a) -> RQ [a]
foreachType f = mapM f . Map.toList =<< asks familyTypes

-- | Like 'foreachType', but the action also receives the number of family
-- types and the 0-based position of the current one.
--
-- The positions follow the same 'Map.toList' order 'foreachType' uses. The
-- generated L/R injections therefore line up with the sum built from
-- 'foreachType'.
foreachTypeNumbered :: (Int -> Int -> (AppliedTyCon,String) -> RQ a) -> RQ [a]
foreachTypeNumbered f = do
    ns <- Map.toList <$> asks familyTypes
    zipWithM (f (length ns)) [0..] ns

-- | Abort because the same type was listed twice in 'familyTypes'.
-- Used as the combining function of 'Map.fromListWithKey'.
collision :: AppliedTyCon -> String -> String -> a
collision k a b = error ("collision : "
                         ++ "\n key = "++pprintUnqual k
                         ++ "\n values = "++show(a,b)
                        )

-- | Resolve the user-supplied options and run an 'RQ' computation.
--
-- Steps:
--
-- 1. Run the type quotations in 'familyTypes'.
-- 2. Reject an empty family.
-- 3. Convert each type to an 'AppliedTyCon', failing with the conversion's
--    error message if that is impossible.
-- 4. Build the lookup map, calling 'collision' on duplicate types.
runRQ :: RQ a -> DerivOptions [(TypeQ,String)] -> Q a
runRQ x opts = do
    quoted <- mapM runQuote (familyTypes opts)
    when (Prelude.null quoted) (fail ("Empty family not supported."))
    resolved <- mapM resolve quoted
    runReaderT x (fmap (const (Map.fromListWithKey collision resolved)) opts)
  where
    runQuote (tq, s) = do
        t <- tq
        return (t, s)
    resolve (t, s) = do
        r <- toAppliedTyCon t
        case r of
            Left err  -> fail err
            Right atc -> return (atc, s)

-- | Generate all declarations for the family described by the options:
--
-- * the constructor types and 'Constructor' instances ('deriveConstructors');
-- * then the pattern functor and the 'El', 'Fam' and 'EqS' instances
--   ('deriveFamily').
--
-- The @type instance PF ...@ line must still be written by hand (see the
-- example in the module header).
deriveEverything :: DerivOptions [(TypeQ, String)] -> Q [Dec]
deriveEverything opts =
    -- let x | mkSanityChecks opts = makeSanityChecks
    --       | otherwise = return []
    runRQ (concat <$> sequence [deriveConstructors, deriveFamily]) opts

-- | For every type of the family, derive one empty datatype per data
-- constructor, together with its instance of class 'Constructor'.
deriveConstructors :: RQ [Dec]
deriveConstructors = concat <$> foreachType constrInstance

-- | Generate the pattern functor type synonym and the 'El', 'Fam' and 'EqS'
-- instances, in that order.
--
-- The constructors of the family GADT are referred to by the names given in
-- 'familyTypes'.
deriveFamily :: RQ [Dec]
deriveFamily = do
    pf  <- derivePF
    el  <- deriveEl
    fam <- deriveFam
    eq  <- deriveEqS
    return $ pf ++ el ++ fam ++ eq

-- | Derive only the pattern functor type synonym (named by
-- 'patternFunctorName'). Not needed if 'deriveFamily' is used.
derivePF :: RQ [Dec]
derivePF = do
    -- One @(constructors :>: T)@ branch per family type, summed with 'sumT'.
    branches <- foreachType pfType
    pfn <- asks patternFunctorName
    let pf = [TySynD (mkName pfn) [] (sumT branches)]
    famName <- asks indexGadtName
    -- message $
    --     (  "*** The pattern functor is:\n"
    --     ++ pprint (cutNames pf)
    --     ++ "\n\n\n"
    --     )
    -- The 'PF' type instance is not generated here, so remind the user to
    -- write it.
    messageReport
        (  "Reminder: Don't forget to add this line manually:\n"
        ++ " type instance PF "++famName++" = "++pfn
        )
    return pf

-- | Combine types with ':+:'. The nesting shape depends on
-- 'bALANCED_MODE'; it must agree with the injections built by 'lrP'/'lrE'.
sumT :: [Type] -> Type
sumT | bALANCED_MODE = balancedSumT
     | otherwise     = rightSumT

-- | Right-nested sum @a :+: (b :+: ...)@. Partial: the list must be
-- non-empty ('foldr1').
rightSumT :: [Type] -> Type
rightSumT = foldr1 plusT

-- | Sum built with 'balancedFold'.
balancedSumT :: [Type] -> Type
balancedSumT = balancedFold plusT

-- | The type @a :+: b@.
plusT :: Type -> Type -> Type
plusT a b = ConT ''(:+:) @@ a @@ b

-- | Right-nested product @a :*: (b :*: ...)@. Partial: the list must be
-- non-empty ('foldr1').
prodT :: [Type] -> Type
prodT = foldr1 timesT

-- | The type @a :*: b@.
timesT :: Type -> Type -> Type
timesT a b = ConT ''(:*:) @@ a @@ b

-- | Derive only the 'El' instances. Not needed if 'deriveFamily'
-- is used.
deriveEl :: RQ [Dec]
deriveEl = foreachType elInstance

-- | The family GADT (named by 'indexGadtName') as a type constructor.
indexGadtType :: RQ Type
indexGadtType = ConT . mkName <$> asks indexGadtName

-- | Derive only the 'Fam' instance. Not needed if 'deriveFamily'
-- is used.
deriveFam :: RQ [Dec]
deriveFam = do
    -- 'from' gets one clause per constructor of every family type;
    -- 'to' gets one clause per family type.
    fcs <- concat <$> foreachTypeNumbered mkFrom
    tcs <- foreachTypeNumbered mkTo
    s <- indexGadtType
    return [ InstanceD [] (ConT ''Fam @@ s) [FunD 'from fcs, FunD 'to tcs] ]

-- | Derive only the 'EqS' instance. Not needed if 'deriveFamily'
-- is used.
deriveEqS :: RQ [Dec]
deriveEqS = do
    s <- indexGadtType
    ns <- Map.elems <$> asks familyTypes
    return [ InstanceD [] (ConT ''EqS @@ s) [FunD 'eqS (trues ns ++ falses ns)] ]
  where
    -- eqS C C = Just Refl, for every family-GADT constructor C.
    trueClause n = sClause [ConP (mkName n) [], ConP (mkName n) []]
                           (ConE 'Just `AppE` ConE 'Refl)
    -- eqS _ _ = Nothing
    falseClause = sClause [WildP, WildP] (ConE 'Nothing)
    trues ns = fmap trueClause ns
    -- With a single index every call matches a true clause, so the
    -- catch-all would be redundant and is omitted.
    falses ns = case ns of
        [_] -> []
        _   -> [falseClause]

-- | For one family type, generate the empty datatype ('mkData') and the
-- 'Constructor' instance ('mkInstance') for each of its constructors.
-- All datatypes come first, then all instances.
constrInstance :: (AppliedTyCon,String) -> RQ [Dec]
constrInstance (atc,s) = do
    cs <- liftq (atc2constructors atc)
    -- runIO (print i)
    ds <- mapM (mkData s) cs
    is <- mapM (mkInstance s) cs
    return $ ds ++ is

-- | Turn a record constructor into the equivalent normal constructor by
-- dropping the field names. Other constructors are returned unchanged.
stripRecordNames :: Con -> Con
stripRecordNames (RecC n f) = NormalC n (fmap (\(_, s, t) -> (s, t)) f)
stripRecordNames c = c

-- | Fail with a readable error for constructor forms the generators do not
-- handle (anything other than normal, record and infix constructors, e.g.
-- 'ForallC'). This replaces an opaque pattern-match failure inside the
-- splice.
unsupportedCon :: String -> Con -> RQ a
unsupportedCon fun c =
    liftq . fail $ "Generics.MultiRec.TH.Alt." ++ fun
                   ++ ": unsupported constructor form: " ++ pprint c

-- TODO: Handle colons in the constructor name
-- | Generate the empty datatype standing for one constructor. Its name is
-- @constructorNameModifier s (cleanConstructorName (nameBase con))@.
mkData :: String -> Con -> RQ Dec
mkData s (NormalC n _) = do
    modifier <- asks constructorNameModifier
    liftq $ dataD (cxt []) (mkName . modifier s . cleanConstructorName . nameBase $ n) [] [] []
mkData s r@(RecC _ _) = mkData s (stripRecordNames r)
mkData s (InfixC t1 n t2) = mkData s (NormalC n [t1,t2])
mkData _ c = unsupportedCon "mkData" c

instance Lift Fixity where
    lift Prefix      = conE 'Prefix
    lift (Infix a n) = conE 'Infix `appE` [| a |] `appE` [| n |]

instance Lift Associativity where
    lift LeftAssociative  = conE 'LeftAssociative
    lift RightAssociative = conE 'RightAssociative
    lift NotAssociative   = conE 'NotAssociative

-- | Generate the 'Constructor' instance for one constructor's empty datatype.
-- 'conName' returns the constructor's base name. For infix constructors,
-- 'conFixity' also returns the fixity obtained by reifying the constructor.
mkInstance :: String -> Con -> RQ Dec
mkInstance s (NormalC n _) = do
    modifier <- asks constructorNameModifier
    let n' = modifier s . cleanConstructorName . nameBase $ n
    liftq $ instanceD (cxt []) (appT (conT ''Constructor) (conT . mkName $ n'))
        [funD 'conName [clause [wildP] (normalB (stringE (nameBase n))) []]]
mkInstance s r@(RecC _ _) = mkInstance s (stripRecordNames r)
mkInstance s (InfixC t1 n t2) = do
    modifier <- asks constructorNameModifier
    let n' = modifier s . cleanConstructorName . nameBase $ n
    i <- liftq (reify n)
    -- Fall back to 'Prefix' if the name does not reify to a data constructor.
    let fi = case i of
                 DataConI _ _ _ f -> convertFixity f
                 _                -> Prefix
    liftq $ instanceD (cxt []) (appT (conT ''Constructor) (conT $ mkName n'))
        [ funD 'conName   [clause [wildP] (normalB (stringE (nameBase n))) []]
        , funD 'conFixity [clause [wildP] (normalB [| fi |]) []]
        ]
  where
    -- Translate TH's fixity into multirec's 'Fixity'.
    convertFixity (Fixity n d) = Infix (convertDirection d) n
    convertDirection InfixL = LeftAssociative
    convertDirection InfixR = RightAssociative
    convertDirection InfixN = NotAssociative
mkInstance _ c = unsupportedCon "mkInstance" c

-- | The pattern-functor branch for one family type:
-- @(sum of 'pfCon' over its constructors) :>: T@.
-- Fails for types without constructors ('guardEmptyData').
pfType :: (AppliedTyCon,String) -> RQ Type
pfType (atc,s) = do
    -- runIO $ putStrLn $ "processing " ++ show n
    cs <- liftq (atc2constructors atc)
    guardEmptyData cs atc
    b <- sumT <$> mapM (pfCon s) cs
    return $ ConT ''(:>:) @@ b @@ fromAppliedTyCon atc

-- | Pattern functor for one constructor:
--
-- * @C Con (f1 :*: ... :*: fn)@ when the constructor has fields;
-- * @C Con U@ when it is nullary.
--
-- Here @Con@ is the empty datatype generated by 'mkData'.
pfCon :: String -> Con -> RQ Type
pfCon s (NormalC n fs) = do
    modifier <- asks constructorNameModifier
    let n' = mkName . modifier s . cleanConstructorName . nameBase $ n
    fieldResults <- mapM (pfField . snd) fs
    let rest = case fs of
                   [] -> ConT ''U
                   _  -> prodT fieldResults
    return $ ConT ''C @@ ConT n' @@ rest
pfCon s r@(RecC _ _) = pfCon s (stripRecordNames r)
pfCon s (InfixC t1 n t2) = pfCon s (NormalC n [t1,t2])
pfCon _ c = unsupportedCon "pfCon" c

-- | @I t@ for a field whose type belongs to the family, @K t@ otherwise.
pfField :: Type -> RQ Type
pfField t = ifInFamily t (ConT ''I @@ t) (ConT ''K @@ t)

-- | The family-GADT constructor name for a type, if the type belongs to the
-- family. A type that 'toAppliedTyCon' cannot convert counts as not in the
-- family.
lookupFam :: Type -> RQ (Maybe String)
lookupFam t = do
    ts <- asks familyTypes
    t' <- liftq $ toAppliedTyCon t
    let res = case t' of
                  Right t'' -> Map.lookup t'' ts
                  Left _    -> Nothing
    -- message ("familyTypes = "++show ts)
    -- message ("lookupFam "++show t'++" = "++show res)
    return res

-- | Pure version of 'ifInFamily''.
ifInFamily :: Type -> a -> a -> RQ a
ifInFamily n x y = ifInFamily' n (return x) (return y)

-- | Run the first action if the type belongs to the family, the second
-- otherwise.
ifInFamily' :: Type -> RQ a -> RQ a -> RQ a
ifInFamily' t x y = maybe y (const x) =<< lookupFam t

-- | @instance El Fam T where proof = Ix@, where @Ix@ is the GADT
-- constructor given for @T@.
elInstance :: (AppliedTyCon,String) -> RQ Dec
elInstance x@(atc,_) = do
    s <- indexGadtType
    prf <- mkProof x
    return $ InstanceD [] (ConT ''El @@ s @@ fromAppliedTyCon atc) [prf]

-- | The 'from' clauses for one family type (position @i@ of @m@).
-- There is one clause per data constructor. Each result is wrapped in 'Tag'
-- and then in the injection for the type's position.
mkFrom :: Int -> Int -> (AppliedTyCon,String) -> RQ [Clause]
mkFrom m i (atc,s) = do
    -- ns <- fmap mkName . elems <$> asks familyTypes
    -- runIO $ putStrLn $ "processing " ++ show n
    cs <- liftq (atc2constructors atc)
    let wrapE = (\e -> lrE m i (ConE 'Tag @@@ e))
        dn = mkName s -- (nameBase n)
    zipWithM (fromCon wrapE dn (length cs)) [0..] cs

-- | The 'to' clause for one family type (position @i@ of @m@). The generated
-- code:
--
-- 1. is a lambda that annotates its argument with the pattern functor type;
-- 2. matches the injection for the type's position;
-- 3. annotates the payload with the type's own branch;
-- 4. dispatches to one match per data constructor.
mkTo :: Int -> Int -> (AppliedTyCon,String) -> RQ Clause
mkTo m i (atc,s) = do
    -- ns <- fmap mkName . elems <$> asks familyTypes
    -- runIO $ putStrLn $ "processing " ++ show n
    cs <- liftq (atc2constructors atc)
    pfname <- mkName <$> asks patternFunctorName
    -- typeOfLamE = ArrowT @@
    --   (ConT pfname @@ ConT ''I0 @@ fromAppliedTyCon atc) @@
    --   (fromAppliedTyCon atc)
    matchesOfCons <- zipWithM (toCon (length cs) atc) [0..] cs
    xvar <- liftq (newName "x")
    convar <- liftq (newName "con")
    typeOfConvar <- do
        t0 <- pfType (atc,s)
        return (t0 @@ ConT ''I0 @@ fromAppliedTyCon atc)
    let typeOfXvar = ConT pfname @@ ConT ''I0 @@ fromAppliedTyCon atc
        body = LamE [VarP xvar]
                 (CaseE (VarE xvar `SigE` typeOfXvar)
                    [ sMatch (lrP m i (VarP convar))
                             (CaseE (VarE convar `SigE` typeOfConvar) matchesOfCons)
                    ]
                 )
    return (sClause [ConP (mkName s) []] body)

-- | @proof = Ix@ for the 'El' instance.
mkProof :: (AppliedTyCon,String) -> RQ Dec
mkProof (_,s) = return $ FunD 'proof [sClause [] (ConE (mkName s))]

-- | One 'from' clause: @from Ix (Con f0 .. fn) = wrap (inj (C (..fields..)))@.
-- Here @inj@ is the injection for constructor @i@ of @m@. A nullary
-- constructor maps to @C U@.
fromCon :: (Exp -> Exp) -> Name -> Int -> Int -> Con -> RQ Clause
fromCon wrap n m i (NormalC cn []) =
    -- Nullary constructor case
    return $ sClause [ConP n [], ConP cn []] (wrap . lrE m i $ ConE 'C @@@ ConE 'U)
fromCon wrap n m i (NormalC cn fs) = do
    rhs <- zipWithM fromField [0..] (snd <$> fs)
    return $
        sClause [ConP n [], ConP cn (fmap (VarP . field) [0..length fs - 1])]
                (wrap . lrE m i $ ConE 'C @@@ foldr1 prod rhs)
  where
    prod x y = ConE '(:*:) @@@ x @@@ y
fromCon wrap n m i r@(RecC _ _) = fromCon wrap n m i (stripRecordNames r)
fromCon wrap n m i (InfixC t1 cn t2) = fromCon wrap n m i (NormalC cn [t1,t2])
fromCon _ _ _ _ c = unsupportedCon "fromCon" c

-- | One match of the inner @case@ in 'to'. It maps the representation of
-- constructor @i@ (out of @m@) back to the constructor applied to its
-- fields.
toCon :: Int -- ^ Number of constructors
      -> AppliedTyCon
      -> Int -- ^ Index of this constructor
      -> Con
      -> RQ Match
toCon m atc i (NormalC cn []) =
    -- Nullary constructor case
    return $ sMatch (ConP 'Tag [lrP m i $ ConP 'C [ConP 'U []]])
                    (ConE cn) -- SigE (ConE cn) (fromAppliedTyCon atc)
toCon m atc i (NormalC cn fs) = do
    -- runIO (putStrLn ("constructor " ++ show ix)) >>
    lhs <- zipWithM toField [0..] (fmap snd fs)
    return $ sMatch (ConP 'Tag [lrP m i $ ConP 'C [foldr1 prod lhs]])
                    (foldl AppE (ConE cn) (fmap (VarE . field) [0..length fs - 1]))
  where
    prod x y = ConP '(:*:) [x,y]
toCon m atc i r@(RecC _ _) = toCon m atc i (stripRecordNames r)
toCon m atc i (InfixC t1 cn t2) = toCon m atc i (NormalC cn [t1,t2])
toCon _ _ _ c = unsupportedCon "toCon" c

-- | Representation of field number @nr@ in 'from'.
--
-- * @I (I0 fN)@ if the field's type is in the family.
-- * @K fN@ otherwise; in that case an informational message is printed
--   when 'verbose' is set.
fromField :: Int -> Type -> RQ Exp
fromField nr t = ifInFamily' t
    (return (ConE 'I @@@ (ConE 'I0 @@@ VarE (field nr))))
    (message ("* Info: Type not in family: " ++ pprintUnqual t) >>
     -- helper t >>
     return (ConE 'K @@@ VarE (field nr)))

-- | Pattern for field number @nr@ in 'to', mirroring 'fromField'.
toField :: Int -> Type -> RQ Pat
toField nr t = ifInFamily t (ConP 'I [ConP 'I0 [VarP (field nr)]]) (ConP 'K [VarP (field nr)])

-- | Variable name for field number @n@: @f0@, @f1@, ...
field :: Int -> Name
field n = mkName $ "f" ++ show n

-- | Wrap a pattern in the 'L'/'R' constructors that select alternative @i@
-- of a sum of @m@ alternatives.
--
-- It switches on 'bALANCED_MODE' like 'sumT' does. The right-nested
-- equations mirror @foldr1@; the balanced case relies on 'ascendFromLeaf'.
-- NOTE(review): assumes 'ascendFromLeaf' mirrors 'balancedFold'.
lrP :: Int -> Int -> (Pat -> Pat)
lrP m i p
    | bALANCED_MODE = ascendFromLeaf (ConP 'L . (:[])) (ConP 'R . (:[])) p m i
lrP 1 0 p = p
lrP m 0 p = ConP 'L [p]
lrP m i p = ConP 'R [lrP (m-1) (i-1) p]

-- | Expression counterpart of 'lrP'.
lrE :: Int -> Int -> (Exp -> Exp)
lrE m i e
    | bALANCED_MODE = ascendFromLeaf (ConE 'L @@@) (ConE 'R @@@) e m i
lrE 1 0 e = e
lrE m 0 e = ConE 'L @@@ e
lrE m i e = ConE 'R @@@ lrE (m-1) (i-1) e

-- | Fail if a family type has no constructors (its sum would be empty).
guardEmptyData :: [Con] -> AppliedTyCon -> RQ ()
guardEmptyData [] atc = fail ("Empty types not supported yet (" ++ show (fromAppliedTyCon atc) ++ ")")
guardEmptyData _ _ = return ()

-- helper t = do
--   Right (AppliedTyCon n args) <- liftq (toAppliedTyCon t)
--   let prefix = "Prf_"
--   str <- if n == ''[]
--      then do
--        Right (AppliedTyCon n1 _) <- liftq (toAppliedTyCon (head args))
--        return ("T("++prefix++"List"++nameBase n1
--                ++",["++pprintUnqual (head args)++"])")
--      else
--        return ("T("++prefix++nameBase n
--                ++","++pprintUnqual t++")")
--   liftq . runIO $ appendFile "dump.dump" (str++"\n")

-- | Same shape as 'SigE', but ignores the type and returns the expression
-- unchanged.
noSigE :: Exp -> Type -> Exp
x `noSigE` y = x

-- makeSanityChecks :: RQ [Dec]
-- makeSanityChecks = concat <$> foreachType makeSanityCheck

-- makeSanityCheck :: (AppliedTyCon,String) -> RQ [Dec]
-- makeSanityCheck (atc,s) = do
--   famname <- mkName <$> asks indexGadtName
--   let
--     chkName = mkName ("sanityCheck"++s)
--   return [
--       SigD chkName (ConT famname @@ fromAppliedTyCon atc)
--     , ValD (VarP chkName)
--            (NormalB (ConE (mkName s)))
--            []
--     ]