module Control.Monad.Mock.TH (makeAction, deriveAction, ts) where
import Control.Monad (replicateM, when, zipWithM)
import Data.Char (toUpper)
import Data.Foldable (traverse_)
import Data.List (foldl', nub, partition)
import Data.Type.Equality ((:~:)(..))
import GHC.Exts (Constraint)
import Language.Haskell.TH
import Control.Monad.Mock (Action(..), MockT, mockAction)
import Control.Monad.Mock.TH.Internal.TypesQuasi (ts)
-- | Generates everything needed to mock the given classes:
--
--   * a GADT named by the first argument, with one constructor per class
--     method (the method name with its first letter upper-cased);
--   * standalone 'Eq' and 'Show' instances for that GADT;
--   * an 'Action' instance for it; and
--   * an instance of every given class for @'MockT' Action m@ (given
--     @'Monad' m@), implemented via 'mockAction'.
--
-- Every element of the 'Cxt' must be a class applied to all of its
-- parameters except the last one. Anything else is rejected at splice time
-- with a descriptive error (see @assertDerivableConstraint@).
makeAction :: String -> Cxt -> Q [Dec]
makeAction actionNameStr classTs = do
  traverse_ assertDerivableConstraint classTs
  actionParamName <- newName "r"
  let actionName = mkName actionNameStr
      actionTypeCon = ConT actionName
  classInfos <- traverse (reify . unappliedTypeName) classTs
  methods <- traverse classMethods classInfos
  actionCons <- concat <$> zipWithM (methodsToConstructors actionTypeCon) classTs methods
  let actionDec = DataD [] actionName [PlainTV actionParamName] Nothing actionCons []
      mkStandaloneDec derivT = standaloneDeriveD' [] (derivT `AppT` (actionTypeCon `AppT` VarT actionParamName))
      standaloneDecs = [mkStandaloneDec (ConT ''Eq), mkStandaloneDec (ConT ''Show)]
  actionInstanceDec <- deriveAction' actionTypeCon actionCons
  classInstanceDecs <- zipWithM (mkInstance actionTypeCon) classTs methods
  return $ [actionDec] ++ standaloneDecs ++ [actionInstanceDec] ++ classInstanceDecs
  where
    -- | Fails (at splice time) unless @classType@ names a class and applies it
    -- to exactly one argument fewer than the class has parameters.
    assertDerivableConstraint :: Type -> Q ()
    assertDerivableConstraint classType = do
      info <- reify $ unappliedTypeName classType
      (ClassD _ _ classVars _ _) <- case info of
        ClassI dec _ -> return dec
        _ -> fail $ "makeAction: expected a constraint, given ‘" ++ show (ppr classType) ++ "’"
      let classArgs = typeArgs classType
          numArgs = length classArgs
          numVars = length classVars
          -- Renders @k1 -> ... -> kn -> Constraint@ for the given class
          -- parameters. The fold keeps them in declaration order; the previous
          -- version reversed the list first and so printed the kinds backwards.
          mkClassKind vars = foldr (\a b -> AppT (AppT ArrowT a) b) (ConT ''Constraint) (map tyVarKind vars)
          -- A binder without a kind annotation is shown as kind *, instead of
          -- crashing the pattern match as the old lambda did.
          tyVarKind (KindedTV _ k) = k
          tyVarKind (PlainTV _) = StarT
          constraintStr = show (ppr (ConT ''Constraint))
      when (numArgs > numVars) $
        fail $ "makeAction: too many arguments for class\n"
            ++ " in: " ++ show (ppr classType) ++ "\n"
            ++ " for class of kind: " ++ show (ppr (mkClassKind classVars))
      when (numArgs == numVars) $
        fail $ "makeAction: cannot derive instance for fully saturated constraint\n"
            ++ " in: " ++ show (ppr classType) ++ "\n"
            ++ " expected: * -> " ++ constraintStr ++ "\n"
            ++ " given: " ++ constraintStr
      -- More than one parameter is still unapplied: the MockT instance can
      -- only fill in the last one.
      when (numArgs < numVars - 1) $
        fail $ "makeAction: cannot derive instance for multi-parameter typeclass\n"
            ++ " in: " ++ show (ppr classType) ++ "\n"
            ++ " expected: * -> " ++ constraintStr ++ "\n"
            ++ " given: " ++ show (ppr (mkClassKind $ drop numArgs classVars))
-- | Converts every method signature of one class into a constructor of the
-- action GADT, preserving their order.
methodsToConstructors :: Type -> Type -> [Dec] -> Q [Con]
methodsToConstructors actionT classT sigs = mapM (methodToConstructor actionT classT) sigs
-- | Turns a single class-method signature into a GADT constructor of the
-- action type. The class constraint is discharged by substituting the action
-- type, the method's foralls and contexts are hoisted onto the constructor,
-- and every argument becomes a field with no strictness or unpacking
-- annotation. Any declaration other than a 'SigD' is an internal error.
methodToConstructor :: Type -> Type -> Dec -> Q Con
methodToConstructor actionT classT (SigD methodName methodType) = do
  conType <- replaceClassConstraint classT actionT methodType
  let (binders, context, fieldTypes, resultType) = splitFnType conType
      plainField = Bang NoSourceUnpackedness NoSourceStrictness
      fields = zip (repeat plainField) fieldTypes
      gadt = GadtC [methodNameToConstructorName methodName] fields resultType
  pure (ForallC binders context gadt)
methodToConstructor _ _ _ = fail "methodToConstructor: internal error; report a bug with the monad-mock package"
-- | Derives the action-constructor name for a class method by upper-casing
-- the first character of the method's unqualified name, e.g. @readFile@
-- becomes @ReadFile@.
--
-- The pattern match is now total. The previous irrefutable @(c:cs)@ binding
-- was an incomplete uni-pattern.
methodNameToConstructorName :: Name -> Name
methodNameToConstructorName name = case nameBase name of
  c : cs -> mkName (toUpper c : cs)
  [] -> error "methodNameToConstructorName: internal error; empty method name; report a bug with the monad-mock package"
-- | Builds @instance Monad m => Class (MockT Action m)@ for one mocked class.
-- The instance has one method per signature, produced by 'mkInstanceMethod'.
mkInstance :: Type -> Type -> [Dec] -> Q Dec
mkInstance actionT classT methodSigs = do
  monadVar <- VarT <$> newName "m"
  let mockT = ConT ''MockT `AppT` actionT `AppT` monadVar
      context = [AppT (ConT ''Monad) monadVar]
  InstanceD Nothing context (AppT classT mockT) <$> mapM mkInstanceMethod methodSigs
-- | Implements one class method for the mock instance. The generated clause
-- binds one fresh variable per argument. It then calls 'mockAction' with the
-- method's unqualified name as a string literal and the matching action
-- constructor applied to those variables. Any declaration other than a
-- 'SigD' is an internal error.
mkInstanceMethod :: Dec -> Q Dec
mkInstanceMethod (SigD methodName methodType) = do
  vars <- replicateM (fnTypeArity methodType) (newName "x")
  let action = foldl' AppE (ConE (methodNameToConstructorName methodName)) (map VarE vars)
      label = LitE (StringL (nameBase methodName))
      body = NormalB (VarE 'mockAction `AppE` label `AppE` action)
  return (FunD methodName [Clause (map VarP vars) body []])
mkInstanceMethod _ = fail "mkInstanceMethod: internal error; report a bug with the monad-mock package"
-- | Rewrites a reified class-method type into the type of the corresponding
-- action constructor. The constraint on @classType@'s class is removed from
-- the outermost @forall@, together with the type variables that constraint
-- mentions. Those variables are then replaced, in order, by @classType@'s
-- arguments followed by @replacementType@.
--
-- If the constraint does not occur exactly once, this now fails in 'Q' with a
-- diagnostic. Previously it crashed on an irrefutable-pattern error.
replaceClassConstraint :: Type -> Type -> Type -> Q Type
replaceClassConstraint classType replacementType (ForallT vars preds typ) =
  case partition ((unappliedClassType ==) . unappliedType) preds of
    ([replacedPred], newPreds) ->
      let replacedVars = typeVarNames replacedPred
          newVars = filter ((`notElem` replacedVars) . tyVarBndrName) vars
          replacedT = foldl' (flip $ uncurry substituteTypeVar) typ (zip replacedVars replacementTypes)
      in return $ ForallT newVars newPreds replacedT
    (matched, _) ->
      fail $ "replaceClassConstraint: internal error; expected exactly one ‘"
          ++ show (ppr unappliedClassType) ++ "’ constraint, found "
          ++ show (length matched) ++ "; report a bug with the monad-mock package"
  where
    unappliedClassType = unappliedType classType
    replacementTypes = typeArgs classType ++ [replacementType]
replaceClassConstraint _ _ _ = fail "replaceClassConstraint: internal error; report a bug with the monad-mock package"
-- | Derives an 'Action' instance for an existing action data type, given the
-- name of its type constructor. Fails unless the name reifies to a @data@
-- declaration.
deriveAction :: Name -> Q [Dec]
deriveAction name = do
  (tyCon, dataCons) <- extractActionInfo =<< reify name
  (: []) <$> deriveAction' tyCon dataCons
  where
    extractActionInfo :: Info -> Q (Type, [Con])
    extractActionInfo info = case info of
      TyConI (DataD _ actionName _ _ cons _) -> return (ConT actionName, cons)
      _ -> fail "deriveAction: expected type constructor"
-- | Builds the 'Action' instance for the given action type constructor. Its
-- only method, 'eqAction', has one clause per data constructor and, where
-- needed, a final @eqAction _ _ = Nothing@ clause.
deriveAction' :: Type -> [Con] -> Q Dec
deriveAction' tyCon dataCons = do
  eqActionDec <- deriveEqAction dataCons
  let instanceHead = ConT ''Action `AppT` tyCon
  return $ InstanceD Nothing [] instanceHead [eqActionDec]
  where
    deriveEqAction :: [Con] -> Q Dec
    deriveEqAction cons = do
      clauses <- traverse deriveEqActionCase cons
      let fallthroughClause = Clause [WildP, WildP] (NormalB (ConE 'Nothing)) []
          -- With exactly one constructor the fallthrough clause would be
          -- redundant. With zero constructors (a class without methods) it is
          -- required: a 'FunD' with no clauses is rejected by GHC. The
          -- previous @length clauses > 1@ test omitted it in that case too.
          clauses' = case clauses of
            [_] -> clauses
            _ -> clauses ++ [fallthroughClause]
      return $ FunD 'eqAction clauses'

    -- | The 'eqAction' clause for one constructor. It matches that constructor
    -- in both arguments, binding fresh x/y variables for the fields. It
    -- returns @Just Refl@ when every corresponding pair of fields is '(==)',
    -- and 'Nothing' otherwise.
    deriveEqActionCase :: Con -> Q Clause
    deriveEqActionCase con = do
      binderNames <- replicateM (conNumArgs con) ((,) <$> newName "x" <*> newName "y")
      let name = conName con
          fstPat = ConP name (map (VarP . fst) binderNames)
          sndPat = ConP name (map (VarP . snd) binderNames)
          mkPairwiseComparison x y = VarE '(==) `AppE` VarE x `AppE` VarE y
          pairwiseComparisons = map (uncurry mkPairwiseComparison) binderNames
          bothComparisons x y = VarE '(&&) `AppE` x `AppE` y
          allComparisons = foldr bothComparisons (ConE 'True) pairwiseComparisons
          conditional = CondE allComparisons (ConE 'Just `AppE` ConE 'Refl) (ConE 'Nothing)
      return $ Clause [fstPat, sndPat] (NormalB conditional) []
-- | The name of a data constructor, looking through 'ForallC'. GADT
-- constructors carrying anything other than a single name are an internal
-- error.
conName :: Con -> Name
conName con = case con of
  NormalC n _ -> n
  RecC n _ -> n
  InfixC _ n _ -> n
  ForallC _ _ inner -> conName inner
  GadtC [n] _ _ -> n
  GadtC ns _ _ -> error $ "conName: internal error; non-singleton GADT constructor names: " ++ show ns
  RecGadtC [n] _ _ -> n
  RecGadtC ns _ _ -> error $ "conName: internal error; non-singleton GADT record constructor names: " ++ show ns
-- | The number of fields a data constructor takes, looking through 'ForallC'.
-- Infix constructors always take exactly two.
conNumArgs :: Con -> Int
conNumArgs con = case con of
  NormalC _ fields -> length fields
  RecC _ fields -> length fields
  InfixC{} -> 2
  ForallC _ _ inner -> conNumArgs inner
  GadtC _ fields _ -> length fields
  RecGadtC _ fields _ -> length fields
-- | Strips every argument from a type application, returning the type
-- constructor at its head (e.g. @ConT ''Maybe@ for @Maybe Int@). A head that
-- is not a 'ConT' is an internal error.
unappliedType :: Type -> Type
unappliedType ty = case ty of
  AppT f _ -> unappliedType f
  ConT _ -> ty
  _ -> error $ "unappliedType: internal error; expected plain applied type, given " ++ show ty
-- | The name of the type constructor at the head of a type application, e.g.
-- @MonadState@ for @MonadState Int@.
--
-- 'unappliedType' only ever returns a 'ConT' (or errors), so the second
-- branch is unreachable. It exists to replace the previous incomplete
-- @let (ConT name) = ...@ pattern with a total match.
unappliedTypeName :: Type -> Name
unappliedTypeName t = case unappliedType t of
  ConT name -> name
  other -> error $ "unappliedTypeName: internal error; expected a type constructor, given " ++ show other
-- | The arguments of a type application, left to right (e.g. @[s, m]@ for
-- @MonadState s m@). A type that is not an application has no arguments.
typeArgs :: Type -> [Type]
typeArgs = go []
  where
    -- Walk down the left spine, consing each argument onto the accumulator,
    -- so the outermost (last) argument ends up at the end of the list.
    go acc (AppT f x) = go (x : acc) f
    go acc _ = acc
-- | Decomposes a (possibly quantified) function type into its type-variable
-- binders, its constraints, its argument types and its final result type.
-- Nested foralls and contexts are concatenated outermost first.
splitFnType :: Type -> ([TyVarBndr], Cxt, [Type], Type)
splitFnType ty = case ty of
  AppT (AppT ArrowT argT) rest ->
    let (tvs, ctx, argTs, resT) = splitFnType rest
    in (tvs, ctx, argT : argTs, resT)
  ForallT tvs ctx body ->
    let (tvs', ctx', argTs, resT) = splitFnType body
    in (tvs ++ tvs', ctx ++ ctx', argTs, resT)
  _ -> ([], [], [], ty)
-- | The number of arguments a (possibly quantified) function type takes.
fnTypeArity :: Type -> Int
fnTypeArity = length . argTypes . splitFnType
  where
    argTypes (_, _, args, _) = args
-- | Replaces every occurrence of the type variable @initial@ with
-- @replacement@. The substitution descends through applications, kind
-- signatures and forall bodies. It does not touch forall binders or contexts,
-- and leaves all other type forms unchanged.
substituteTypeVar :: Name -> Type -> Type -> Type
substituteTypeVar initial replacement = go
  where
    go ty = case ty of
      VarT n | n == initial -> replacement
      AppT f x -> AppT (go f) (go x)
      ForallT tvs ctx body -> ForallT tvs ctx (go body)
      SigT t k -> SigT (go t) k
      _ -> ty
-- | The distinct type-variable names occurring in a type built from variables
-- and applications, in order of first occurrence. Other type forms contribute
-- no names.
typeVarNames :: Type -> [Name]
typeVarNames = nub . occurrences
  where
    occurrences (VarT n) = [n]
    occurrences (AppT f x) = occurrences f ++ occurrences x
    occurrences _ = []
-- | The name bound by a type-variable binder, with or without a kind
-- annotation.
tyVarBndrName :: TyVarBndr -> Name
tyVarBndrName bndr = case bndr of
  PlainTV n -> n
  KindedTV n _ -> n
-- | The member declarations of a reified class, minus any @default@ method
-- signatures. Non-class 'Info' is an internal error reported through 'fail'.
classMethods :: Info -> Q [Dec]
classMethods info = case info of
  ClassI (ClassD _ _ _ _ members) _ -> return (filter (not . isDefaultSig) members)
  other -> fail $ "classMethods: internal error; expected a class type, given " ++ show other
  where
    isDefaultSig DefaultSigD{} = True
    isDefaultSig _ = False
-- | Builds a standalone @deriving instance ctx => ty@ declaration across
-- template-haskell versions. From 2.12.0 'StandaloneDerivD' takes an extra
-- leading argument (the deriving strategy), which is always 'Nothing' here.
standaloneDeriveD' :: Cxt -> Type -> Dec
#if MIN_VERSION_template_haskell(2,12,0)
standaloneDeriveD' ctx ty = StandaloneDerivD Nothing ctx ty
#else
standaloneDeriveD' ctx ty = StandaloneDerivD ctx ty
#endif