module Generics.MultiRec.TH.Alt
(
DerivOptions(..),
deriveEverything,
) where
import Generics.MultiRec.TH.Alt.DerivOptions(DerivOptions(..))
import THUtils(AppliedTyCon, (@@), (@@@), toAppliedTyCon,
fromAppliedTyCon, atc2constructors, pprintUnqual, sMatch, sClause,
cleanConstructorName)
import BalancedFold(balancedFold, ascendFromLeaf)
import MonadRQ(RQ, message, messageReport, liftq, foreachType,
foreachTypeNumbered, runRQ)
import Generics.MultiRec.Base((:>:)(..), C(..), El(..), (:=:)(..),
I0(I0), (:*:)(..), (:+:)(..), EqS(..), Fam(..), I(I), K(K), U(..))
import Generics.MultiRec.Constructor(Associativity(..), Fixity(..),
Constructor(..))
import Control.Monad.Reader(Monad(return, fail, (>>)), Functor(..),
(=<<), mapM, sequence, liftM, zipWithM, asks)
import Language.Haskell.TH.Syntax(Lift(..))
import Language.Haskell.TH(newName, mkName, wildP, clause, conE,
appE, normalB, funD, dataD, instanceD, cxt, conT, appT,
Exp(VarE, SigE, LamE, ConE, CaseE, AppE), Match, Clause, Q,
Pat(WildP, VarP, ConP), TypeQ, Type(ConT),
Dec(TySynD, InstanceD, FunD), Name,
Con(RecC, NormalC, InfixC),
FixityDirection(..),
Info(DataConI), nameBase, reify, stringE)
import Data.Map(lookup, elems)
import Control.Applicative((<$>))
import qualified Data.Map as Map
import qualified Language.Haskell.TH as TH
-- | Module-level switch selecting how sums (and the matching 'L'/'R'
-- injections) are nested in the generated code.
--
-- When 'True', the pattern-functor sum is built with 'balancedFold' (see
-- 'sumT') and injections with 'ascendFromLeaf' (see 'lrP' / 'lrE').  When
-- 'False' (the current setting), sums are right-nested via 'foldr1' and
-- injections are built by plain recursion.  'sumT', 'lrP' and 'lrE' all
-- consult this one flag, so the type-level and value-level encodings
-- always use the same nesting.
bALANCED_MODE :: Bool
bALANCED_MODE = False
-- | Entry point: generate every declaration for the family described by
-- the options.
--
-- First come the per-constructor declarations (an empty marker data type
-- and a 'Constructor' instance for each constructor of each member), then
-- the family-level ones (pattern functor synonym, 'El', 'Fam' and 'EqS'
-- instances).
deriveEverything :: DerivOptions [(TypeQ, String)] -> Q [Dec]
deriveEverything opts = runRQ generate opts
  where
    generate = liftM concat (sequence [deriveConstructors, deriveFamily])
-- | For every family member, emit its constructors' marker data types and
-- 'Constructor' instances (see 'constrInstance'), flattened into one list.
deriveConstructors :: RQ [Dec]
deriveConstructors = fmap concat (foreachType constrInstance)
-- | Generate the family-level declarations, in this order: the pattern
-- functor synonym, the 'El' instances, the 'Fam' instance and the 'EqS'
-- instance.
deriveFamily :: RQ [Dec]
deriveFamily =
  concat <$> sequence [derivePF, deriveEl, deriveFam, deriveEqS]
-- | Build the pattern functor as a type synonym named by
-- 'patternFunctorName'.  It is a sum with one branch per family member
-- (see 'pfType').
--
-- The @type instance PF ...@ line is not generated here, so a reminder to
-- add it by hand is reported via 'messageReport'.
derivePF :: RQ [Dec]
derivePF = do
  memberBranches <- foreachType pfType
  synonymName    <- asks patternFunctorName
  gadtName       <- asks indexGadtName
  messageReport $
    "Reminder: Don't forget to add this line manually:\n"
      ++ " type instance PF " ++ gadtName ++ " = " ++ synonymName
  return [TySynD (mkName synonymName) [] (sumT memberBranches)]
-- | Combine branch types into a single sum, using the nesting selected by
-- 'bALANCED_MODE'.  Partial: the list must be non-empty.
sumT :: [Type] -> Type
sumT = if bALANCED_MODE then balancedSumT else rightSumT

-- | Right-nested sum: folds the branches with 'plusT' from the right.
-- Partial: fails on an empty list.
rightSumT :: [Type] -> Type
rightSumT branches = foldr1 plusT branches

-- | Balanced sum built with 'balancedFold' over 'plusT'.
balancedSumT :: [Type] -> Type
balancedSumT branches = balancedFold plusT branches

-- | The binary sum type @l :+: r@.
plusT :: Type -> Type -> Type
plusT l r = ConT ''(:+:) @@ l @@ r

-- | Right-nested product of field types, folded with 'timesT'.
-- Partial: fails on an empty list.
prodT :: [Type] -> Type
prodT factors = foldr1 timesT factors

-- | The binary product type @l :*: r@.
timesT :: Type -> Type -> Type
timesT l r = ConT ''(:*:) @@ l @@ r
-- | One 'El' instance per family member (see 'elInstance').
deriveEl :: RQ [Dec]
deriveEl = foreachType elInstance

-- | The index GADT as a type constructor, named by 'indexGadtName'.
indexGadtType :: RQ Type
indexGadtType = do
  gadtName <- asks indexGadtName
  return (ConT (mkName gadtName))
-- | The 'Fam' instance for the index GADT.
--
-- 'from' receives the clauses of every member, one per constructor (see
-- 'mkFrom').  'to' receives one clause per member (see 'mkTo').
deriveFam :: RQ [Dec]
deriveFam = do
  fromClauses <- concat <$> foreachTypeNumbered mkFrom
  toClauses   <- foreachTypeNumbered mkTo
  gadt        <- indexGadtType
  let famInstance =
        InstanceD [] (ConT ''Fam @@ gadt)
          [FunD 'from fromClauses, FunD 'to toClauses]
  return [famInstance]
-- | The 'EqS' instance for the index GADT.
--
-- Each index constructor @N@ named in 'familyTypes' gets a clause
-- @eqS N N = Just Refl@.  A final catch-all @eqS _ _ = Nothing@ is
-- appended unless the family has exactly one member.  In that case the
-- single diagonal clause presumably already covers every pair, and a
-- catch-all would only be a redundant match.
deriveEqS :: RQ [Dec]
deriveEqS = do
  s  <- indexGadtType
  ns <- elems <$> asks familyTypes
  return [ InstanceD [] (ConT ''EqS @@ s) [FunD 'eqS (eqClauses ns)] ]
  where
    -- Same index constructor on both sides: the indices are equal.
    trueClause n =
      sClause [ConP (mkName n) [], ConP (mkName n) []]
        (ConE 'Just `AppE` ConE 'Refl)
    -- Any other pair of indices is unequal.
    falseClause = sClause [WildP, WildP] (ConE 'Nothing)
    -- Build the diagonal clauses once and decide on the catch-all by
    -- pattern instead of measuring the clause list.
    eqClauses ns = case ns of
      [_] -> map trueClause ns
      _   -> map trueClause ns ++ [falseClause]
-- | For one family member (with index constructor name @s@), emit a marker
-- data type for each of its constructors, followed by the matching
-- 'Constructor' instances.
constrInstance :: (AppliedTyCon,String) -> RQ [Dec]
constrInstance (atc, s) = do
  cons      <- liftq (atc2constructors atc)
  dataDecls <- mapM (mkData s) cons
  instDecls <- mapM (mkInstance s) cons
  return (dataDecls ++ instDecls)
-- | Turn a record constructor into the equivalent positional constructor
-- by dropping the field names.  Any other constructor is returned as is.
stripRecordNames :: Con -> Con
stripRecordNames con = case con of
  RecC name fields ->
    NormalC name [ (strictness, ty) | (_, strictness, ty) <- fields ]
  _ -> con
-- | The empty marker data type for one constructor.
--
-- The type's name is the constructor's base name passed through
-- 'cleanConstructorName' and then 'constructorNameModifier' (which also
-- receives @s@, the member's index constructor name).  Record and infix
-- constructors are first normalised to the positional form.  Any other
-- constructor form is rejected with an explicit error instead of an
-- opaque pattern-match failure inside the splice.
mkData :: String -> Con -> RQ Dec
mkData s (NormalC n _) = do
  modifier <- asks constructorNameModifier
  liftq $ dataD (cxt []) (mkName . modifier s . cleanConstructorName . nameBase $ n) [] [] []
mkData s r@(RecC _ _) =
  mkData s (stripRecordNames r)
mkData s (InfixC t1 n t2) =
  mkData s (NormalC n [t1,t2])
mkData _ c =
  fail ("mkData: unsupported constructor form: " ++ show c)
-- | Lets a 'Fixity' computed at splice time be embedded in generated code
-- as a constructor application.
instance Lift Fixity where
  lift fixity = case fixity of
    Prefix                 -> conE 'Prefix
    Infix assoc precedence ->
      appE (appE (conE 'Infix) (lift assoc)) (lift precedence)
-- | Lets an 'Associativity' be embedded in generated code as the
-- corresponding nullary constructor.
instance Lift Associativity where
  lift assoc = conE $ case assoc of
    LeftAssociative  -> 'LeftAssociative
    RightAssociative -> 'RightAssociative
    NotAssociative   -> 'NotAssociative
-- | The 'Constructor' instance for one constructor's marker type (see
-- 'mkData').
--
-- 'conName' returns the constructor's base name.  For infix constructors
-- the declared fixity is obtained with 'reify' and returned from
-- 'conFixity'.  If 'reify' does not yield a 'DataConI', the fixity falls
-- back to 'Prefix'.  Record constructors are first normalised to the
-- positional form.  Any other constructor form is rejected with an
-- explicit error.
mkInstance :: String -> Con -> RQ Dec
mkInstance s con = case con of
  NormalC n _  -> instanceFor n []
  RecC _ _     -> mkInstance s (stripRecordNames con)
  InfixC _ n _ -> do
    info <- liftq (reify n)
    -- Only data-constructor info carries a fixity; anything else is
    -- treated as prefix.
    let fixity = case info of
          DataConI _ _ _ f -> convertFixity f
          _                -> Prefix
    instanceFor n [funD 'conFixity [clause [wildP] (normalB [| fixity |]) []]]
  _ -> fail ("mkInstance: unsupported constructor form: " ++ show con)
  where
    -- @instance Constructor Marker where conName _ = "<base name>"@, plus
    -- any extra method definitions.  Shared by the prefix and infix cases.
    instanceFor n extraMethods = do
      modifier <- asks constructorNameModifier
      let marker = mkName . modifier s . cleanConstructorName . nameBase $ n
          conNameMethod =
            funD 'conName [clause [wildP] (normalB (stringE (nameBase n))) []]
      liftq $
        instanceD (cxt []) (appT (conT ''Constructor) (conT marker))
          (conNameMethod : extraMethods)
    -- Translate Template Haskell's fixity into multirec's representation.
    -- (Binder renamed from @n@ so it no longer shadows the constructor name.)
    convertFixity (TH.Fixity prec dir) = Infix (convertDirection dir) prec
    convertDirection InfixL = LeftAssociative
    convertDirection InfixR = RightAssociative
    convertDirection InfixN = NotAssociative
-- | The pattern-functor branch for one family member.
--
-- The branch is the sum of the member's constructor representations (see
-- 'pfCon'), tagged with the member type via @:>:@.  Members without
-- constructors are rejected by 'guardEmptyData'.
pfType :: (AppliedTyCon,String) -> RQ Type
pfType (atc, s) = do
  cons <- liftq (atc2constructors atc)
  guardEmptyData cons atc
  conReps <- mapM (pfCon s) cons
  return (ConT ''(:>:) @@ sumT conReps @@ fromAppliedTyCon atc)
-- | Representation type of one constructor: @C Marker fields@.
--
-- @Marker@ is the constructor's marker type (same naming as 'mkData').
-- @fields@ is the right-nested product of the 'pfField' types, or 'U' for
-- a nullary constructor.  Record and infix constructors are first
-- normalised to the positional form.  Any other constructor form is
-- rejected with an explicit error.
pfCon :: String -> Con -> RQ Type
pfCon s (NormalC n fs) = do
  modifier <- asks constructorNameModifier
  let n' = mkName . modifier s . cleanConstructorName . nameBase $ n
  fieldResults <- mapM (pfField . snd) fs
  let rest =
        case fs of
          [] -> ConT ''U
          _  -> prodT fieldResults
  return $
    ConT ''C @@ ConT n' @@ rest
pfCon s r@(RecC _ _) =
  pfCon s (stripRecordNames r)
pfCon s (InfixC t1 n t2) =
  pfCon s (NormalC n [t1,t2])
pfCon _ c =
  fail ("pfCon: unsupported constructor form: " ++ show c)
-- | Representation of a field type: @I t@ when @t@ is a family member,
-- @K t@ otherwise.
pfField :: Type -> RQ Type
pfField t = ifInFamily t (wrapWith ''I) (wrapWith ''K)
  where
    wrapWith tyCon = ConT tyCon @@ t
-- | Look a type up in 'familyTypes'.
--
-- Returns the associated name when the type converts to an
-- 'AppliedTyCon' that is a key of the map.  Returns 'Nothing' when the
-- key is absent or when the conversion fails.
lookupFam :: Type -> RQ (Maybe String)
lookupFam t = do
  members   <- asks familyTypes
  converted <- liftq (toAppliedTyCon t)
  return $ either (const Nothing) (\key -> Map.lookup key members) converted
-- | Pure variant of 'ifInFamily'''.  Returns the first value if the type
-- is a family member, the second one otherwise.
ifInFamily :: Type -> a -> a -> RQ a
ifInFamily t inFam notInFam = ifInFamily' t (return inFam) (return notInFam)

-- | Run the first action if the type is a family member (see
-- 'lookupFam'), the second one otherwise.
ifInFamily' :: Type -> RQ a -> RQ a -> RQ a
ifInFamily' t inFam notInFam = do
  found <- lookupFam t
  case found of
    Just _  -> inFam
    Nothing -> notInFam
-- | @instance El Gadt T where proof = ...@ for one family member @T@ (see
-- 'mkProof').
elInstance :: (AppliedTyCon,String) -> RQ Dec
elInstance member@(atc, _) = do
  gadt     <- indexGadtType
  proofDec <- mkProof member
  let instanceHead = ConT ''El @@ gadt @@ fromAppliedTyCon atc
  return (InstanceD [] instanceHead [proofDec])
-- | The 'from' clauses for family member number @i@ of @m@, one per
-- constructor (see 'fromCon').
--
-- Each constructor's representation is wrapped in 'Tag' and then injected
-- into this member's position of the family sum with 'lrE'.
mkFrom :: Int -> Int -> (AppliedTyCon,String) -> RQ [Clause]
mkFrom m i (atc, s) = do
  cons <- liftq (atc2constructors atc)
  let indexCon = mkName s
      inject e = lrE m i (ConE 'Tag @@@ e)
      conCount = length cons
  sequence [ fromCon inject indexCon conCount k c | (k, c) <- zip [0..] cons ]
-- | The 'to' clause for family member number @i@ of @m@, with index
-- constructor @S@:
--
-- > to S = \x -> case (x :: PF I0 T) of
-- >   <lrP m i con> -> case (con :: <pfType T> I0 T) of
-- >     <one alternative per constructor, see 'toCon'>
--
-- Both scrutinees carry explicit 'SigE' annotations, presumably so the
-- GADT matches type-check.  A stray empty @let@ that preceded the first
-- bind has been removed.
mkTo :: Int -> Int -> (AppliedTyCon,String) -> RQ Clause
mkTo m i (atc,s) = do
  cs <- liftq (atc2constructors atc)
  pfname <- mkName <$> asks patternFunctorName
  matchesOfCons <- zipWithM (toCon (length cs) atc) [0..] cs
  xvar <- liftq (newName "x")
  convar <- liftq (newName "con")
  branchType <- pfType (atc,s)
  let memberType   = fromAppliedTyCon atc
      typeOfConvar = branchType @@ ConT ''I0 @@ memberType
      typeOfXvar   = ConT pfname @@ ConT ''I0 @@ memberType
      body =
        LamE [VarP xvar]
          (CaseE (VarE xvar `SigE` typeOfXvar)
            [ sMatch (lrP m i (VarP convar))
                (CaseE (VarE convar `SigE` typeOfConvar) matchesOfCons)
            ])
  return (sClause [ConP (mkName s) []] body)
-- | The 'proof' method of an 'El' instance: @proof = S@, where @S@ is the
-- member's index constructor.
mkProof :: (AppliedTyCon,String) -> RQ Dec
mkProof (_, s) =
  let indexCon = ConE (mkName s)
  in return (FunD 'proof [sClause [] indexCon])
-- | One 'from' clause, for constructor number @i@ of @m@ of the family
-- member whose index constructor is @n@.
--
-- The clause matches @n@ and the constructor applied to field variables
-- @f0, f1, ...@, and builds @wrap (lrE m i (C fields))@.  Here @fields@ is
-- the right-nested '(:*:)' product of the 'fromField' expressions, or 'U'
-- for a nullary constructor.  Record and infix constructors are first
-- normalised to the positional form.  Any other constructor form is
-- rejected with an explicit error.
fromCon :: (Exp -> Exp) -> Name -> Int -> Int -> Con -> RQ Clause
fromCon wrap n m i (NormalC cn []) = return $
  sClause
    [ConP n [], ConP cn []]
    (wrap . lrE m i $ ConE 'C @@@ ConE 'U)
fromCon wrap n m i (NormalC cn fs) = do
  rhs <- zipWithM fromField [0..] (snd <$> fs)
  return $
    sClause
      [ ConP n []
      -- One pattern variable per field: f0 .. f(k-1).  (The upper bound
      -- was previously written as the ill-typed @length fs 1@.)
      , ConP cn (map (VarP . field) [0 .. length fs - 1])
      ]
      (wrap . lrE m i $ ConE 'C @@@ foldr1 prod rhs)
  where
    prod x y = ConE '(:*:) @@@ x @@@ y
fromCon wrap n m i r@(RecC _ _) =
  fromCon wrap n m i (stripRecordNames r)
fromCon wrap n m i (InfixC t1 cn t2) =
  fromCon wrap n m i (NormalC cn [t1,t2])
fromCon _ _ _ _ c =
  fail ("fromCon: unsupported constructor form: " ++ show c)
-- | One case alternative of 'to', for constructor number @i@ of @m@.
--
-- The alternative matches @Tag (lrP m i (C fields))@ and rebuilds the
-- constructor from the field variables @f0, f1, ...@ (see 'toField').
-- Nullary constructors match @Tag (... (C U))@.  Record and infix
-- constructors are first normalised to the positional form.  Any other
-- constructor form is rejected with an explicit error.  The
-- 'AppliedTyCon' argument is not used when building the match.
toCon ::
     Int
  -> AppliedTyCon
  -> Int
  -> Con
  -> RQ Match
toCon m _ i (NormalC cn []) = return $
  sMatch
    (ConP 'Tag [lrP m i $ ConP 'C [ConP 'U []]])
    (ConE cn)
toCon m _ i (NormalC cn fs) = do
  lhs <- zipWithM toField [0..] (map snd fs)
  return $
    sMatch
      (ConP 'Tag [lrP m i $ ConP 'C [foldr1 prod lhs]])
      -- Apply the constructor to f0 .. f(k-1).  (The upper bound was
      -- previously written as the ill-typed @length fs 1@.)
      (foldl AppE (ConE cn) (map (VarE . field) [0 .. length fs - 1]))
  where
    prod x y = ConP '(:*:) [x,y]
toCon m atc i r@(RecC _ _) =
  toCon m atc i (stripRecordNames r)
toCon m atc i (InfixC t1 cn t2) =
  toCon m atc i (NormalC cn [t1,t2])
toCon _ _ _ c =
  fail ("toCon: unsupported constructor form: " ++ show c)
-- | Expression for field number @nr@ (of type @t@) in a 'from' clause.
--
-- For a family member it is @I (I0 fN)@.  Otherwise it is @K fN@, after
-- reporting via 'message' that the type is not in the family.
fromField :: Int -> Type -> RQ Exp
fromField nr t = ifInFamily' t recursive constant
  where
    var = VarE (field nr)
    recursive = return (ConE 'I @@@ (ConE 'I0 @@@ var))
    constant = do
      message ("* Info: Type not in family: " ++ pprintUnqual t)
      return (ConE 'K @@@ var)
-- | Pattern for field number @nr@ (of type @t@) in a 'to' alternative:
-- @I (I0 fN)@ for a family member, @K fN@ otherwise.
toField :: Int -> Type -> RQ Pat
toField nr t = ifInFamily t (ConP 'I [ConP 'I0 [var]]) (ConP 'K [var])
  where
    var = VarP (field nr)
-- | Name of the variable bound to field number @n@: @f0@, @f1@, ...
field :: Int -> Name
field n = mkName ('f' : show n)
-- | Wrap a pattern in the 'L'/'R' constructors that select alternative @i@
-- (0-based) of an @m@-ary sum.
--
-- With the right-nested encoding ('bALANCED_MODE' off), each step skips
-- one alternative with 'R' and recurses on the remaining @m - 1@.  The
-- chosen alternative gets an 'L', unless it is the last one.  For @m = 3@
-- the results are @L p@, @R (L p)@ and @R (R p)@.  With balanced sums the
-- path is computed by 'ascendFromLeaf'.
--
-- The recursive call previously read @lrP (m1) (i1) p@, which refers to
-- undefined names; it now decrements both counters.
lrP :: Int -> Int -> (Pat -> Pat)
lrP m i p
  | bALANCED_MODE = ascendFromLeaf (ConP 'L . (:[])) (ConP 'R . (:[])) p m i
lrP 1 0 p = p
lrP _ 0 p = ConP 'L [p]
lrP m i p = ConP 'R [lrP (m - 1) (i - 1) p]
-- | Expression counterpart of 'lrP'.  Wraps an expression in the 'L'/'R'
-- injections that select alternative @i@ (0-based) of an @m@-ary sum.
--
-- With the right-nested encoding ('bALANCED_MODE' off), each step skips
-- one alternative with 'R' and recurses on the remaining @m - 1@.  The
-- chosen alternative gets an 'L', unless it is the last one.  With
-- balanced sums the path is computed by 'ascendFromLeaf'.
--
-- The recursive call previously read @lrE (m1) (i1) e@, which refers to
-- undefined names; it now decrements both counters.
lrE :: Int -> Int -> (Exp -> Exp)
lrE m i e
  | bALANCED_MODE = ascendFromLeaf (ConE 'L @@@) (ConE 'R @@@) e m i
lrE 1 0 e = e
lrE _ 0 e = ConE 'L @@@ e
lrE m i e = ConE 'R @@@ lrE (m - 1) (i - 1) e
-- | Abort code generation for a family member with no constructors.
--
-- Such a member would need an empty sum, which 'sumT' (built on
-- 'foldr1' / 'balancedFold') cannot represent.  The error message now
-- closes the parenthesis it opens.
guardEmptyData :: [Con] -> AppliedTyCon -> RQ ()
guardEmptyData [] atc =
  fail ("Empty types not supported yet ("
          ++ show (fromAppliedTyCon atc) ++ ")")
guardEmptyData _ _ = return ()
-- | Stand-in for 'SigE' that ignores the type and returns the expression
-- unchanged.  It is not referenced elsewhere in this module.
noSigE :: Exp -> Type -> Exp
noSigE e _ = e