{-# LANGUAGE CPP #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE Trustworthy #-}
module Grisette.Core.TH
(
makeUnionWrapper,
makeUnionWrapper',
)
where
import Control.Monad (join, replicateM, when, zipWithM)
import Grisette.Core.THCompat (augmentFinalType)
import Language.Haskell.TH
( Body (NormalB),
Clause (Clause),
Con (ForallC, GadtC, InfixC, NormalC, RecC, RecGadtC),
Dec (DataD, FunD, NewtypeD, SigD),
Exp (AppE, ConE, LamE, VarE),
Info (DataConI, TyConI),
Name,
Pat (VarP),
Q,
Type (ForallT),
mkName,
newName,
pprint,
reify,
)
import Language.Haskell.TH.Syntax (Name (Name), OccName (OccName))
-- | Generate one wrapper per data constructor of a type, using explicitly
-- provided wrapper names.
--
-- The names are paired positionally with the constructors returned by
-- 'getConstructors'. The splice fails if the two lists differ in length.
makeUnionWrapper' ::
  -- | Names for the generated wrappers, one per constructor.
  [String] ->
  -- | The type whose constructors are wrapped.
  Name ->
  Q [Dec]
makeUnionWrapper' names typName = do
  constructors <- getConstructors typName
  if length names /= length constructors
    then fail "Number of names does not match the number of constructors"
    else concat <$> zipWithM mkSingleWrapper names constructors
-- | The unqualified (occurrence) part of a Template Haskell 'Name'.
-- Any module qualifier or name flavour is discarded.
occName :: Name -> String
occName n = case n of
  Name (OccName str) _flavour -> str
-- | Extract the unqualified name of a data constructor. 'makeUnionWrapper'
-- uses it to derive wrapper names.
--
-- * Prefix ('NormalC') and record ('RecC') constructors are supported.
-- * Existentially quantified constructors ('ForallC') are looked through.
-- * GADT constructors ('GadtC', 'RecGadtC') are supported only when they
--   carry exactly one name.
-- * Infix constructors are rejected. The error tells the caller to supply
--   names manually via @makeUnionWrapper'@.
-- * Any other form fails with the pretty-printed constructor.
getConstructorName :: Con -> Q String
getConstructorName con = case con of
  NormalC n _ -> found n
  RecC n _ -> found n
  InfixC {} ->
    fail "You should use makeUnionWrapper' to manually provide the name for infix constructors"
  ForallC _ _ inner -> getConstructorName inner
  GadtC [n] _ _ -> found n
  RecGadtC [n] _ _ -> found n
  _ -> fail $ "Unsupported constructor at this time: " ++ pprint con
  where
    found n = pure (occName n)
-- | Reify a type name and return the data constructors of its declaration.
--
-- Only @data@ and @newtype@ declarations are accepted. Any other reified
-- 'Info' makes the splice fail with the pretty-printed info.
getConstructors :: Name -> Q [Con]
getConstructors typName = reify typName >>= fromInfo
  where
    fromInfo :: Info -> Q [Con]
    fromInfo (TyConI (DataD _ _ _ _ cons _)) = pure cons
    fromInfo (TyConI (NewtypeD _ _ _ _ con _)) = pure [con]
    fromInfo info = fail $ "Unsupported declaration: " ++ pprint info
-- | Generate one wrapper per data constructor of a type. Each wrapper is
-- named by prepending a prefix to the constructor's unqualified name.
--
-- For example, @makeUnionWrapper "mrg" ''Maybe@ names its wrappers
-- @mrgNothing@ and @mrgJust@. Constructors whose name cannot be derived
-- (e.g. infix ones) make the splice fail. Use @makeUnionWrapper'@ for those.
makeUnionWrapper ::
  -- | Prefix prepended to each constructor name.
  String ->
  -- | The type whose constructors are wrapped.
  Name ->
  Q [Dec]
makeUnionWrapper prefix typName = do
  baseNames <- getConstructors typName >>= mapM getConstructorName
  makeUnionWrapper' (map (prefix ++) baseNames) typName
-- | Build the expression @\\x1 ... xn -> mrgSingle (f x1 ... xn)@ for an
-- @n@-ary expression @f@. 'mkSingleWrapper' passes a data constructor here.
--
-- Fresh variable names come from 'newName'. When @n@ is 0 the lambda has no
-- binders.
--
-- NOTE(review): the quote requires @mrgSingle@ to be in scope in this
-- module, but it is not in the explicit import list. Confirm where it comes
-- from.
augmentNormalCExpr :: Int -> Exp -> Q Exp
augmentNormalCExpr n f = do
  vars <- replicateM n (newName "x")
  single <- [|mrgSingle|]
  let applied = foldl AppE f (VarE <$> vars)
      binders = VarP <$> vars
  pure (LamE binders (AppE single applied))
-- | Rewrite a reified constructor type for use as a wrapper signature.
--
-- The outermost @forall@ (if any) is split off. Its body goes through
-- 'augmentFinalType'. The binders and constraints returned by
-- 'augmentFinalType' are placed before the original ones in the resulting
-- 'ForallT'. The result is always a 'ForallT', even when the input had no
-- outer @forall@.
augmentNormalCType :: Type -> Q Type
augmentNormalCType ty = do
  let (outerBndrs, outerCtx, inner) = splitOuterForall ty
  ((bndrs, preds), augmentedTyp) <- augmentFinalType inner
  return $ ForallT (bndrs ++ outerBndrs) (preds ++ outerCtx) augmentedTyp
  where
    -- A type without an outer forall contributes no binders or context.
    splitOuterForall (ForallT bs cx t) = (bs, cx, t)
    splitOuterForall t = ([], [], t)
-- | Generate the declarations for one wrapper, named @name@, around the data
-- constructor described by the given 'Con'.
--
-- Two declarations are produced:
--
-- * a type signature: the constructor's reified type, transformed by
--   'augmentNormalCType';
-- * a binding @name = \\x1 ... xn -> mrgSingle (C x1 ... xn)@, where @n@ is
--   the constructor's field count (built by 'augmentNormalCExpr').
--
-- Supported constructors:
--
-- * prefix ('NormalC'), record ('RecC') and infix ('InfixC');
-- * existentially quantified ('ForallC'), by looking through to the
--   wrapped constructor;
-- * single-name GADT constructors ('GadtC', 'RecGadtC').
--
-- This matches what 'getConstructorName' accepts. It also means infix
-- constructors now work through @makeUnionWrapper'@, as that function's
-- error message promises. Any other form fails with the pretty-printed
-- constructor.
mkSingleWrapper :: String -> Con -> Q [Dec]
mkSingleWrapper name con = do
  (oriName, arity) <- conNameAndArity con
  DataConI _ constructorTyp _ <- reify oriName
  augmentedTyp <- augmentNormalCType constructorTyp
  expr <- augmentNormalCExpr arity (ConE oriName)
  let retName = mkName name
  return
    [ SigD retName augmentedTyp,
      FunD retName [Clause [] (NormalB expr) []]
    ]
  where
    -- The data constructor's name and the number of fields it takes.
    conNameAndArity :: Con -> Q (Name, Int)
    conNameAndArity (NormalC n fields) = return (n, length fields)
    conNameAndArity (RecC n fields) = return (n, length fields)
    -- An infix constructor always has exactly two fields.
    conNameAndArity (InfixC _ n _) = return (n, 2)
    conNameAndArity (ForallC _ _ inner) = conNameAndArity inner
    conNameAndArity (GadtC [n] fields _) = return (n, length fields)
    conNameAndArity (RecGadtC [n] fields _) = return (n, length fields)
    -- Report the outermost constructor, not the inner one reached via ForallC.
    conNameAndArity _ = fail $ "Unsupported constructor: " ++ pprint con