module Generics.MultiRec.TH.Alt
(
DerivOptions(..),
deriveEverything,
) where
import Generics.MultiRec.Base
import Generics.MultiRec.Constructor
import Language.Haskell.TH hiding (Fixity())
import Language.Haskell.TH.Syntax (Lift(..))
import Language.Haskell.TH.ExpandSyns
import Language.Haskell.TH.Ppr
import Control.Monad
import Control.Monad.Reader hiding (lift)
import qualified Control.Monad.Reader as Reader
import Control.Applicative
import Control.Arrow
import THUtils
import Data.Map as Map
import Data.Set as Set hiding(elems)
import qualified Data.Foldable as Fold
import BalancedFold
-- | Compile-time switch read by 'sumT', 'lrP' and 'lrE': when 'True' they use
-- the balanced helpers ('balancedSumT', 'ascendFromLeaf'); otherwise the
-- right-nested variants are used.
bALANCED_MODE :: Bool
bALANCED_MODE = True
-- | Options driving the derivation, parameterised over how the family's
-- types are represented.
data DerivOptions ft = DerivOptions
    { familyTypes :: ft
      -- ^ The types of the family. 'runRQ' takes a list of @(TypeQ, String)@
      -- pairs and converts it to a map keyed by applied type constructor.
    , indexGadtName :: String
      -- ^ Name of the index GADT; the generated @Fam@, @El@ and @EqS@
      -- instances are declared for the type with this name.
    , constructorNameModifier :: String -> String -> String
      -- ^ Builds the name of a generated constructor tag type. It is called
      -- with the string paired with the family type and the cleaned
      -- constructor name.
    , patternFunctorName :: String
      -- ^ Name of the generated pattern-functor type synonym.
    , verbose :: Bool
      -- ^ When 'True', 'message' and 'messageReport' emit their output.
    }
-- | Maps over the 'familyTypes' field only; every other option is kept.
instance Functor DerivOptions where
    fmap g opts = opts { familyTypes = g (familyTypes opts) }
-- | Turn a constructor's base name into something usable inside an
-- identifier.
--
-- * @"[]"@ becomes @"NIL"@ and @":"@ becomes @"CONS"@.
-- * A name wrapped in parentheses (the tuple constructors, e.g. @"(,)"@)
--   becomes @"TUPLE" ++ show (length name - 1)@, so @"(,)"@ gives @"TUPLE2"@.
-- * Anything else, including the empty string, is returned unchanged.
cleanConstructorName :: [Char] -> [Char]
cleanConstructorName c = case c of
    "[]" -> "NIL"
    ":"  -> "CONS"
    -- Pattern matching replaces head/last so that "" and "(" cannot crash.
    ('(' : rest@(_ : _))
        | last rest == ')' -> "TUPLE" ++ show (length c - 1)
    _ -> c
-- | Print a line (followed by an extra newline) to stdout at splice time,
-- but only when the 'verbose' option is set.
message :: String -> RQ ()
message x = do
    enabled <- asks verbose
    when enabled $
        liftq (runIO (putStrLn (x ++ "\n")))
-- | Hand a message (with a trailing newline) to Template Haskell's 'report'
-- with the flag set to 'False', but only when the 'verbose' option is set.
messageReport :: String -> RQ ()
messageReport x = do
    enabled <- asks verbose
    when enabled $
        liftq (report False (x ++ "\n"))
-- | The derivation monad: 'Q' with read-only access to the 'DerivOptions',
-- whose family is stored as a map from applied type constructor to the
-- string supplied for that type.
type RQ = ReaderT (DerivOptions (Map AppliedTyCon String)) Q
-- | Run a plain 'Q' action inside 'RQ'.
liftq :: Q a -> RQ a
liftq q = Reader.lift q
-- | Apply an action to every @(type, string)@ entry of the family, in the
-- order given by 'Map.toList', and collect the results.
foreachType :: ((AppliedTyCon,String) -> RQ a) -> RQ [a]
foreachType f = do
    fam <- asks familyTypes
    mapM f (Map.toList fam)
-- | Like 'foreachType', but the action also receives the total number of
-- family entries and the zero-based position of the current entry.
foreachTypeNumbered :: (Int -> Int -> (AppliedTyCon,String) -> RQ a) -> RQ [a]
foreachTypeNumbered f = do
    entries <- fmap Map.toList (asks familyTypes)
    let total = length entries
    zipWithM (f total) [0 ..] entries
-- | Abort with an 'error' describing a duplicated family key and the two
-- strings that were supplied for it. Used as the combining function of
-- 'Map.fromListWithKey' in 'runRQ'.
collision :: AppliedTyCon -> String -> String -> a
collision key v1 v2 =
    error $ concat
        [ "collision : "
        , "\n key = ", pprintUnqual key
        , "\n values = ", show (v1, v2)
        ]
-- | Run an 'RQ' computation from user-facing options.
--
-- The steps, in order: run every 'TypeQ' of the family; fail if the family
-- is empty; convert each type with 'toAppliedTyCon' (failing with its error
-- message on 'Left'); build the family map, where a duplicated key triggers
-- 'collision'; finally run the reader with that map installed.
runRQ :: RQ a -> DerivOptions [(TypeQ,String)] -> Q a
runRQ action opts = do
    resolved <- forM (familyTypes opts) $ \(tq, ixName) -> do
        t <- tq
        return (t, ixName)
    when (Prelude.null resolved) (fail "Empty family not supported.")
    applied <- forM resolved $ \(t, ixName) ->
        toAppliedTyCon t >>= either fail (\atc -> return (atc, ixName))
    let famMap = Map.fromListWithKey collision applied
    runReaderT action (opts { familyTypes = famMap })
-- | Entry point: generate the declarations of 'deriveConstructors' followed
-- by those of 'deriveFamily'.
deriveEverything :: DerivOptions [(TypeQ, String)] -> Q [Dec]
deriveEverything opts = runRQ allDecs opts
  where
    allDecs = do
        constructorDecs <- deriveConstructors
        familyDecs      <- deriveFamily
        return (constructorDecs ++ familyDecs)
-- | Declarations produced by 'constrInstance' for every family type,
-- concatenated.
deriveConstructors :: RQ [Dec]
deriveConstructors = do
    perType <- foreachType constrInstance
    return (concat perType)
-- | The pattern functor synonym, then the @El@, @Fam@ and @EqS@ instances,
-- generated and concatenated in that order.
deriveFamily :: RQ [Dec]
deriveFamily =
    concat <$> sequence [derivePF, deriveEl, deriveFam, deriveEqS]
-- | Generate the pattern-functor type synonym, named by
-- 'patternFunctorName', as the sum of the 'pfType' branches of all family
-- types. Also reports a reminder that the @PF@ type instance has to be
-- written by hand.
derivePF :: RQ [Dec]
derivePF = do
    branches <- foreachType pfType
    pfn      <- asks patternFunctorName
    famName  <- asks indexGadtName
    messageReport $
        "Reminder: Don't forget to add this line manually:\n"
        ++ " type instance PF " ++ famName ++ " = " ++ pfn
    return [TySynD (mkName pfn) [] (sumT branches)]
-- | Combine types with @:+:@, balanced or right-nested depending on
-- 'bALANCED_MODE'.
sumT :: [Type] -> Type
sumT = if bALANCED_MODE then balancedSumT else rightSumT
-- | Right-nested @:+:@ sum of a list of types. Partial: uses 'foldr1', so
-- the list must be non-empty.
rightSumT :: [Type] -> Type
rightSumT ts = foldr1 plusT ts
-- | @:+:@ sum of a list of types, combined with 'balancedFold'.
balancedSumT :: [Type] -> Type
balancedSumT ts = balancedFold plusT ts
-- | The type @(:+:)@ applied to the two given types.
plusT :: Type -> Type -> Type
plusT l r = ConT ''(:+:) @@ l @@ r
-- | Right-nested @:*:@ product of a list of types. Partial: uses 'foldr1',
-- so the list must be non-empty.
prodT :: [Type] -> Type
prodT ts = foldr1 timesT ts
-- | The type @(:*:)@ applied to the two given types.
timesT :: Type -> Type -> Type
timesT l r = ConT ''(:*:) @@ l @@ r
-- | One @El@ instance (see 'elInstance') per family type.
deriveEl :: RQ [Dec]
deriveEl = do
    instances <- foreachType elInstance
    return instances
-- | The index GADT as a type constructor, named by 'indexGadtName'.
indexGadtType :: RQ Type
indexGadtType = do
    gadtName <- asks indexGadtName
    return (ConT (mkName gadtName))
-- | Generate the @Fam@ instance for the index GADT. The @from@ clauses come
-- from 'mkFrom' (several per family type) and the @to@ clauses from 'mkTo'
-- (one per family type).
deriveFam :: RQ [Dec]
deriveFam = do
    fromClauses <- concat <$> foreachTypeNumbered mkFrom
    toClauses   <- foreachTypeNumbered mkTo
    ix          <- indexGadtType
    let methods = [FunD 'from fromClauses, FunD 'to toClauses]
    return [InstanceD [] (ConT ''Fam @@ ix) methods]
-- | Generate the @EqS@ instance for the index GADT. @eqS@ gets one clause
-- per family entry, matching the same constructor twice and yielding
-- @Just Refl@. A catch-all clause yielding @Nothing@ is appended unless the
-- family has exactly one entry.
deriveEqS :: RQ [Dec]
deriveEqS = do
    ix    <- indexGadtType
    names <- Map.elems <$> asks familyTypes
    let clauses = fmap reflClause names ++ fallback names
    return [InstanceD [] (ConT ''EqS @@ ix) [FunD 'eqS clauses]]
  where
    reflClause n =
        sClause [ConP (mkName n) [], ConP (mkName n) []]
                (ConE 'Just `AppE` ConE 'Refl)
    -- No wildcard clause when there is exactly one witness.
    fallback [_] = []
    fallback _   = [sClause [WildP, WildP] (ConE 'Nothing)]
-- | For one family type: the tag data declarations ('mkData') of all its
-- constructors, followed by their @Constructor@ instances ('mkInstance').
constrInstance :: (AppliedTyCon,String) -> RQ [Dec]
constrInstance (atc, ixName) = do
    cons     <- liftq (atc2constructors atc)
    dataDecs <- forM cons (mkData ixName)
    instDecs <- forM cons (mkInstance ixName)
    return (dataDecs ++ instDecs)
-- | Turn a record constructor into a normal one by dropping the field
-- names; any other constructor is returned unchanged.
stripRecordNames :: Con -> Con
stripRecordNames con = case con of
    RecC name fields ->
        NormalC name [ (strictness, ty) | (_, strictness, ty) <- fields ]
    _ -> con
-- | Declare the empty data type that serves as the tag of a constructor.
-- Its name is 'constructorNameModifier' applied to the given string and the
-- cleaned base name of the constructor. Record and infix constructors are
-- reduced to the normal-constructor case.
mkData :: String -> Con -> RQ Dec
mkData s (NormalC n _) = do
    rename <- asks constructorNameModifier
    let tagName = mkName (rename s (cleanConstructorName (nameBase n)))
    liftq (dataD (cxt []) tagName [] [] [])
mkData s con@(RecC _ _) =
    mkData s (stripRecordNames con)
mkData s (InfixC lhs n rhs) =
    mkData s (NormalC n [lhs, rhs])
-- | Lift a constructor fixity into an expression that rebuilds it.
instance Lift Fixity where
    lift Prefix      = [| Prefix |]
    lift (Infix a n) = [| Infix a n |]
-- | Lift an associativity into an expression that rebuilds it.
instance Lift Associativity where
    lift LeftAssociative  = [| LeftAssociative |]
    lift RightAssociative = [| RightAssociative |]
    lift NotAssociative   = [| NotAssociative |]
-- | Generate the @Constructor@ instance for a constructor's tag type.
--
-- Normal constructors (and records, after 'stripRecordNames') only define
-- @conName@. Infix constructors also define @conFixity@: the constructor is
-- reified and its Template Haskell fixity converted, falling back to
-- @Prefix@ when 'reify' returns anything other than a @DataConI@.
mkInstance :: String -> Con -> RQ Dec
mkInstance s (NormalC n _) = do
    rename <- asks constructorNameModifier
    let tagName   = mkName (rename s (cleanConstructorName (nameBase n)))
        conNameFn = funD 'conName
                        [clause [wildP] (normalB (stringE (nameBase n))) []]
    liftq $
        instanceD (cxt [])
                  (conT ''Constructor `appT` conT tagName)
                  [conNameFn]
mkInstance s con@(RecC _ _) =
    mkInstance s (stripRecordNames con)
mkInstance s (InfixC _ n _) = do
    rename <- asks constructorNameModifier
    let tagName = mkName (rename s (cleanConstructorName (nameBase n)))
    info <- liftq (reify n)
    let fi = case info of
            DataConI _ _ _ f -> convertFixity f
            _                -> Prefix
        conNameFn   = funD 'conName
                          [clause [wildP] (normalB (stringE (nameBase n))) []]
        conFixityFn = funD 'conFixity
                          [clause [wildP] (normalB [| fi |]) []]
    liftq $
        instanceD (cxt [])
                  (conT ''Constructor `appT` conT tagName)
                  [conNameFn, conFixityFn]
  where
    -- Translate a Template Haskell fixity into the Fixity that is lifted.
    convertFixity (Fixity prec dir) = Infix (convertDirection dir) prec
    convertDirection InfixL = LeftAssociative
    convertDirection InfixR = RightAssociative
    convertDirection InfixN = NotAssociative
-- | The pattern-functor branch of one family type: the sum of its
-- constructors' representations ('pfCon'), tagged with the type via @:>:@.
-- Fails through 'guardEmptyData' if the type has no constructors.
pfType :: (AppliedTyCon,String) -> RQ Type
pfType (atc, ixName) = do
    cons <- liftq (atc2constructors atc)
    guardEmptyData cons atc
    conTypes <- mapM (pfCon ixName) cons
    return (ConT ''(:>:) @@ sumT conTypes @@ fromAppliedTyCon atc)
-- | Representation of one constructor: @C tag rest@, where @tag@ is the
-- constructor's tag type and @rest@ is @U@ for a nullary constructor or the
-- product of the field representations ('pfField') otherwise. Record and
-- infix constructors are reduced to the normal-constructor case.
pfCon :: String -> Con -> RQ Type
pfCon s (NormalC n fs) = do
    rename     <- asks constructorNameModifier
    fieldTypes <- mapM (pfField . snd) fs
    let tagName = mkName (rename s (cleanConstructorName (nameBase n)))
        payload
            | Prelude.null fs = ConT ''U
            | otherwise       = prodT fieldTypes
    return (ConT ''C @@ ConT tagName @@ payload)
pfCon s con@(RecC _ _) =
    pfCon s (stripRecordNames con)
pfCon s (InfixC lhs n rhs) =
    pfCon s (NormalC n [lhs, rhs])
-- | Representation of a field: @I t@ when @t@ belongs to the family, @K t@
-- otherwise.
pfField :: Type -> RQ Type
pfField t = ifInFamily t recursive constant
  where
    recursive = ConT ''I @@ t
    constant  = ConT ''K @@ t
-- | Look a type up in the family map. Yields 'Nothing' both when the type is
-- not in the family and when 'toAppliedTyCon' returns 'Left'.
lookupFam :: Type -> RQ (Maybe String)
lookupFam t = do
    fam       <- asks familyTypes
    converted <- liftq (toAppliedTyCon t)
    return (either (const Nothing) (\key -> Map.lookup key fam) converted)
-- | Pure variant of 'ifInFamily'': the first value if the type is in the
-- family, the second otherwise.
ifInFamily :: Type -> a -> a -> RQ a
ifInFamily t inFam notInFam =
    ifInFamily' t (return inFam) (return notInFam)
-- | Run the first action if 'lookupFam' finds the type in the family, the
-- second one otherwise.
ifInFamily' :: Type -> RQ a -> RQ a -> RQ a
ifInFamily' t whenIn whenOut = do
    found <- lookupFam t
    case found of
        Just _  -> whenIn
        Nothing -> whenOut
-- | The @El@ instance for one family type; its only method is the @proof@
-- built by 'mkProof'.
elInstance :: (AppliedTyCon,String) -> RQ Dec
elInstance entry@(atc, _) = do
    ix       <- indexGadtType
    proofDec <- mkProof entry
    let instHead = ConT ''El @@ ix @@ fromAppliedTyCon atc
    return (InstanceD [] instHead [proofDec])
-- | The @from@ clauses for one family type, one per constructor.
--
-- The first argument is the number of family types and the second is this
-- type's position among them. Every constructor's representation is wrapped
-- in @Tag@ and then injected into the family-level sum with 'lrE'.
mkFrom :: Int -> Int -> (AppliedTyCon,String) -> RQ [Clause]
mkFrom m i (atc, s) = do
    cons <- liftq (atc2constructors atc)
    let tagAt e  = lrE m i (ConE 'Tag @@@ e)
        ixCon    = mkName s
        conCount = length cons
    zipWithM (fromCon tagAt ixCon conCount) [0 ..] cons
-- | The @to@ clause for one family type.
--
-- The clause matches the type's index constructor and returns a lambda. The
-- lambda scrutinises its argument, annotated with the pattern functor
-- applied to @I0@ and the type, selects this type's branch with 'lrP', and
-- then dispatches on the branch, annotated with its 'pfType', using one
-- match per constructor from 'toCon'.
mkTo :: Int -> Int -> (AppliedTyCon,String) -> RQ Clause
mkTo m i (atc, s) = do
    cons       <- liftq (atc2constructors atc)
    pfName     <- fmap mkName (asks patternFunctorName)
    conMatches <- zipWithM (toCon (length cons) atc) [0 ..] cons
    xName      <- liftq (newName "x")
    conVarName <- liftq (newName "con")
    branchTy   <- pfType (atc, s)
    let self        = fromAppliedTyCon atc
        -- Apply a functor type to I0 and the family type.
        annotate ty = ty @@ ConT ''I0 @@ self
        inner       = CaseE (VarE conVarName `SigE` annotate branchTy)
                            conMatches
        outer       = CaseE (VarE xName `SigE` annotate (ConT pfName))
                            [sMatch (lrP m i (VarP conVarName)) inner]
    return (sClause [ConP (mkName s) []] (LamE [VarP xName] outer))
-- | The @proof@ method of an @El@ instance: a clause with no arguments
-- whose body is the constructor named by the entry's string.
mkProof :: (AppliedTyCon,String) -> RQ Dec
mkProof (_, s) =
    return (FunD 'proof [sClause [] witness])
  where
    witness = ConE (mkName s)
-- | Build the @from@ clause for one constructor.
--
-- Arguments: a wrapper applied to the finished representation; the name of
-- the index constructor to match first; the number of constructors of the
-- type; this constructor's position; and the constructor itself.
--
-- The clause matches the index constructor and the data constructor, with
-- fields bound to the names given by 'field'. It returns @C@ applied to
-- @U@ (no fields) or to the right-nested @:*:@ product of the
-- 'fromField' results, injected into the constructor sum with 'lrE' and
-- then passed to the wrapper. Record and infix constructors are reduced to
-- the normal-constructor case.
fromCon :: (Exp -> Exp) -> Name -> Int -> Int -> Con -> RQ Clause
fromCon wrap n m i (NormalC cn []) = return $
    sClause
        [ConP n [], ConP cn []]
        (wrap . lrE m i
            $ ConE 'C @@@ ConE 'U)
fromCon wrap n m i (NormalC cn fs) = do
    rhs <- zipWithM fromField [0 ..] (snd <$> fs)
    return $
        sClause
            [ ConP n []
              -- One variable per field: f0 .. f(k-1).
            , ConP cn (fmap (VarP . field) [0 .. length fs - 1])
            ]
            (wrap . lrE m i
                $ ConE 'C @@@ foldr1 prod rhs)
  where
    prod x y = ConE '(:*:) @@@ x @@@ y
fromCon wrap n m i r@(RecC _ _) =
    fromCon wrap n m i (stripRecordNames r)
fromCon wrap n m i (InfixC t1 cn t2) =
    fromCon wrap n m i (NormalC cn [t1, t2])
-- | Build the @to@ match for one constructor.
--
-- Arguments: the number of constructors of the type; the type itself (not
-- used by the visible code); this constructor's position; and the
-- constructor.
--
-- The pattern is @Tag@ around the 'lrP' injection of @C@ applied to @U@ (no
-- fields) or to the right-nested @:*:@ product of the 'toField' patterns.
-- The body re-applies the data constructor to the field variables named by
-- 'field'. Record and infix constructors are reduced to the
-- normal-constructor case.
toCon ::
       Int
    -> AppliedTyCon
    -> Int
    -> Con
    -> RQ Match
toCon m atc i (NormalC cn []) = return $
    sMatch
        (ConP 'Tag [lrP m i $ ConP 'C [ConP 'U []]])
        (ConE cn)
toCon m atc i (NormalC cn fs) = do
    lhs <- zipWithM toField [0 ..] (fmap snd fs)
    return $
        sMatch
            (ConP 'Tag [lrP m i $ ConP 'C [foldr1 prod lhs]])
            -- Qualified: Data.Map and Data.Set are imported unqualified and
            -- may export their own foldl.
            (Prelude.foldl AppE (ConE cn)
                (fmap (VarE . field) [0 .. length fs - 1]))
  where
    prod x y = ConP '(:*:) [x, y]
toCon m atc i r@(RecC _ _) =
    toCon m atc i (stripRecordNames r)
toCon m atc i (InfixC t1 cn t2) =
    toCon m atc i (NormalC cn [t1, t2])
-- | Expression wrapping the @nr@-th field variable: @I (I0 f)@ when the
-- field's type is in the family, @K f@ otherwise. In the latter case an
-- informational 'message' is emitted first.
fromField :: Int -> Type -> RQ Exp
fromField nr t = ifInFamily' t inFamily outsideFamily
  where
    var           = VarE (field nr)
    inFamily      = return (ConE 'I @@@ (ConE 'I0 @@@ var))
    outsideFamily = do
        message ("* Info: Type not in family: " ++ pprintUnqual t)
        return (ConE 'K @@@ var)
-- | Pattern binding the @nr@-th field variable: @I (I0 f)@ when the field's
-- type is in the family, @K f@ otherwise.
toField :: Int -> Type -> RQ Pat
toField nr t = ifInFamily t recursive constant
  where
    var       = VarP (field nr)
    recursive = ConP 'I [ConP 'I0 [var]]
    constant  = ConP 'K [var]
-- | Name of the variable used for the @n@-th constructor field: @f0@, @f1@,
-- and so on.
field :: Int -> Name
field n = mkName ('f' : show n)
-- | Wrap a pattern in the @L@/@R@ constructors that select alternative @i@
-- (zero-based) out of @m@ in a sum.
--
-- When 'bALANCED_MODE' is set, the work is delegated to 'ascendFromLeaf'.
-- Otherwise the sum is right-nested: alternative 0 is under @L@ (or bare,
-- if it is the only alternative) and any other alternative is under @R@ of
-- the same selection in the remaining @m - 1@ alternatives.
lrP :: Int -> Int -> ( Pat -> Pat)
lrP m i p
    | bALANCED_MODE = ascendFromLeaf
        (ConP 'L . (: []))
        (ConP 'R . (: []))
        p
        m
        i
lrP 1 0 p = p
lrP m 0 p = ConP 'L [p]
lrP m i p = ConP 'R [lrP (m - 1) (i - 1) p]
-- | Wrap an expression in the @L@/@R@ constructors that inject it as
-- alternative @i@ (zero-based) out of @m@ in a sum.
--
-- When 'bALANCED_MODE' is set, the work is delegated to 'ascendFromLeaf'.
-- Otherwise the sum is right-nested: alternative 0 is under @L@ (or bare,
-- if it is the only alternative) and any other alternative is under @R@ of
-- the same injection in the remaining @m - 1@ alternatives.
lrE :: Int -> Int -> ( Exp -> Exp)
lrE m i e
    | bALANCED_MODE = ascendFromLeaf
        (ConE 'L @@@)
        (ConE 'R @@@)
        e
        m
        i
lrE 1 0 e = e
lrE m 0 e = ConE 'L @@@ e
lrE m i e = ConE 'R @@@ lrE (m - 1) (i - 1) e
-- | Fail the derivation when a family type has no constructors, naming the
-- offending type in the message; otherwise do nothing.
guardEmptyData :: [Con] -> AppliedTyCon -> RQ ()
guardEmptyData [] atc =
    -- The trailing ")" closes the parenthesis opened in the message.
    fail ("Empty types not supported yet ("
          ++ show (fromAppliedTyCon atc) ++ ")")
guardEmptyData _ _ = return ()
-- | Same argument shape as applying 'SigE' infix, but the type is discarded
-- and the expression is returned unannotated.
noSigE :: Exp -> Type -> Exp
noSigE e _ = e