module Control.Monad.TestFixture.TH.Internal where
#if MIN_VERSION_base(4,9,0)
import qualified Control.Monad.Fail as Fail
#endif
import qualified Control.Monad.Reader as Reader
import Prelude hiding (log)
import Control.Monad (join, replicateM, when, zipWithM)
import Control.Monad.TestFixture (TestFixture, TestFixtureT, unimplemented)
import Data.Char (isPunctuation, isSymbol)
import Data.Default.Class (Default(..))
import Data.List (foldl', nub, partition)
import GHC.Exts (Constraint)
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
-- | Generate everything needed for a fixture called @fixtureNameStr@:
--
--   * a record type with one field per method of every class in
--     @classTypes@ (see 'mkFixtureRecord');
--   * a 'Default' instance for that record (see 'mkDefaultInstance');
--   * the family of convenience type synonyms (see 'mkFixtureTypeSynonyms');
--   * one instance per class for 'TestFixtureT' over the record
--     (see 'mkInstance').
--
-- Every class is validated with 'assertDerivableConstraint' before any
-- declaration is built.
mkFixture :: String -> [Type] -> Q [Dec]
mkFixture fixtureNameStr classTypes = do
  mapM_ assertDerivableConstraint classTypes
  (recordDec, recordFields) <- mkFixtureRecord recordName classTypes
  synonymDecs <- mkFixtureTypeSynonyms recordName
  defaultDec <- mkDefaultInstance recordName recordFields
  classInstanceDecs <- mapM (`mkInstance` recordName) classTypes
  return (recordDec : defaultDec : synonymDecs ++ classInstanceDecs)
  where
    recordName = mkName fixtureNameStr
-- | Build the fixture record declaration, roughly
--
-- > data Fixture (m :: * -> *) = Fixture { <one field per class method> }
--
-- It reifies each class in @classTypes@ and converts each class's method
-- signatures into fields with 'methodsToFields'. The field list is returned
-- alongside the declaration so other generators can reuse it.
mkFixtureRecord :: Name -> [Type] -> Q (Dec, [VarStrictType])
mkFixtureRecord fixtureName classTypes = do
  classInfos <- mapM (reify . unappliedTypeName) classTypes
  methodSigs <- mapM classMethods classInfos
  monadName <- newName "m"
  fieldGroups <- zipWithM (methodsToFields monadName) classTypes methodSigs
  let fields = concat fieldGroups
      -- The record's single type parameter has kind * -> *.
      monadKind = AppT (AppT ArrowT StarT) StarT
      binder = KindedTV monadName monadKind
      recordDec = mkDataD [] fixtureName [binder] [RecC fixtureName fields]
  return (recordDec, fields)
-- | Generate the convenience type synonyms for the fixture @fixtureName@.
-- For a fixture named @Fixture@ the result is, in this order:
--
-- > type FixturePure                   = Fixture (TestFixture  Fixture ()  ())
-- > type FixtureLog log                = Fixture (TestFixture  Fixture log ())
-- > type FixtureState state            = Fixture (TestFixture  Fixture ()  state)
-- > type FixtureLogState log state     = Fixture (TestFixture  Fixture log state)
-- > type FixturePureT m                = Fixture (TestFixtureT Fixture ()  ()    m)
-- > type FixtureLogT log m             = Fixture (TestFixtureT Fixture log ()    m)
-- > type FixtureStateT state m         = Fixture (TestFixtureT Fixture ()  state m)
-- > type FixtureLogStateT log state m  = Fixture (TestFixtureT Fixture log state m)
mkFixtureTypeSynonyms :: Name -> Q [Dec]
mkFixtureTypeSynonyms fixtureName = do
  monadName <- newName "m"
  logName <- newName "log"
  stateName <- newName "state"
  let monadV = VarT monadName
      logV = VarT logName
      stateV = VarT stateName
  return $ map synonym
    [ ("Pure",      [],                              pureFixture unitT unitT)
    , ("Log",       [logName],                       pureFixture logV unitT)
    , ("State",     [stateName],                     pureFixture unitT stateV)
    , ("LogState",  [logName, stateName],            pureFixture logV stateV)
    , ("PureT",     [monadName],                     transFixture unitT unitT monadV)
    , ("LogT",      [logName, monadName],            transFixture logV unitT monadV)
    , ("StateT",    [stateName, monadName],          transFixture unitT stateV monadV)
    , ("LogStateT", [logName, stateName, monadName], transFixture logV stateV monadV)
    ]
  where
    unitT = TupleT 0
    -- @type <fixture name><suffix> <binders> = rhs@
    synonym (suffix, binders, rhs) =
      TySynD (mkName (nameBase fixtureName ++ suffix)) (map PlainTV binders) rhs
    -- Apply the fixture record type constructor to a monad type.
    wrapFixture = AppT (ConT fixtureName)
    pureFixture logT stateT =
      wrapFixture (foldl' AppT (ConT ''TestFixture) [ConT fixtureName, logT, stateT])
    transFixture logT stateT innerT =
      wrapFixture (foldl' AppT (ConT ''TestFixtureT) [ConT fixtureName, logT, stateT, innerT])
-- | Build @instance Default (Fixture m)@. Its 'def' is a record construction
-- that sets every field in @fixtureFields@ to @unimplemented "<field name>"@
-- (see 'unimplementedField').
mkDefaultInstance :: Name -> [VarStrictType] -> Q Dec
mkDefaultInstance fixtureName fixtureFields = do
  monadName <- newName "m"
  let instanceHead = ConT ''Default `AppT` (ConT fixtureName `AppT` VarT monadName)
      defBody = RecConE fixtureName
        [unimplementedField fieldName | (fieldName, _, _) <- fixtureFields]
      defMethod = FunD 'def [Clause [] (NormalB defBody) []]
  return (mkInstanceD [] instanceHead [defMethod])
-- | Build @instance Monad m => C (TestFixtureT Fixture w s m)@ for the
-- (possibly partially applied) class @classType@. Each method of the class
-- is implemented by 'mkDictInstanceFunc'.
mkInstance :: Type -> Name -> Q Dec
mkInstance classType fixtureName = do
  w <- newName "w"
  s <- newName "s"
  m <- newName "m"
  let monadV = VarT m
      fixtureMonad =
        foldl' AppT (ConT ''TestFixtureT) [ConT fixtureName, VarT w, VarT s, monadV]
  methods <- classMethods =<< reify (unappliedTypeName classType)
  methodDecs <- mapM mkDictInstanceFunc methods
  return $ mkInstanceD [ConT ''Monad `AppT` monadV] (classType `AppT` fixtureMonad) methodDecs
-- | Check, before any code is generated, that @classType@ is a constraint
-- 'mkFixture' can derive an instance for. It must name a type class, and the
-- arguments already applied to it must leave exactly one class parameter
-- free.
--
-- The splice fails with a descriptive message in any of these cases:
--
--   * the name is not a class;
--   * more arguments are supplied than the class has parameters;
--   * the constraint is fully saturated;
--   * more than one parameter is left free (a multi-parameter class).
assertDerivableConstraint :: Type -> Q ()
assertDerivableConstraint classType = do
  info <- reify $ unappliedTypeName classType
  classVars <- case info of
    ClassI (ClassD _ _ vars _ _) _ -> return vars
    _ -> fail $ "mkFixture: expected a constraint, given ‘" ++ show (ppr classType) ++ "’"
  let argCount = length (typeArgs classType)
      varCount = length classVars
      constraintStr = show (ppr (ConT ''Constraint))
      -- Kind of a class with the given parameters, in declaration order:
      -- k1 -> ... -> kn -> Constraint.
      mkClassKind vars =
        foldr (\k rest -> AppT (AppT ArrowT k) rest) (ConT ''Constraint) (map tyVarKind vars)
      -- A binder without an explicit kind annotation is reported as kind *.
      tyVarKind (KindedTV _ k) = k
      tyVarKind (PlainTV _) = StarT
  when (argCount > varCount) $
    fail $ "mkFixture: too many arguments for class\n"
      ++ " in: " ++ show (ppr classType) ++ "\n"
      ++ " for class of kind: " ++ show (ppr (mkClassKind classVars))
  when (argCount == varCount) $
    fail $ "mkFixture: cannot derive instance for fully saturated constraint\n"
      ++ " in: " ++ show (ppr classType) ++ "\n"
      ++ " expected: * -> " ++ constraintStr ++ "\n"
      ++ " given: " ++ constraintStr
  -- Two or more free parameters: a multi-parameter class whose leading
  -- parameters were not all supplied.
  when (argCount < varCount - 1) $
    fail $ "mkFixture: cannot derive instance for multi-parameter typeclass\n"
      ++ " in: " ++ show (ppr classType) ++ "\n"
      ++ " expected: * -> " ++ constraintStr ++ "\n"
      ++ " given: " ++ show (ppr (mkClassKind $ drop argCount classVars))
-- | Method declarations of a reified class, with @default@ method signatures
-- ('DefaultSigD') removed. Fails for any 'Info' that is not a class.
classMethods :: MonadFail m => Info -> m [Dec]
classMethods info = case info of
  ClassI (ClassD _ _ _ _ decs) _ -> return [dec | dec <- decs, not (isDefaultSig dec)]
  _ -> fail $ "classMethods: expected a class type, given " ++ show info
  where
    isDefaultSig DefaultSigD{} = True
    isDefaultSig _ = False
-- | Convert all method signatures of one class into fixture record fields,
-- in order, using 'methodToField'. @name@ is the record's monad type variable.
methodsToFields :: MonadFail m => Name -> Type -> [Dec] -> m [VarStrictType]
methodsToFields name typ decs = mapM (methodToField name typ) decs
-- | Turn one method signature @name :: ty@ into a record field.
--
--   * The field name comes from 'methodNameToFieldName'.
--   * The field type is @ty@ after 'replaceClassConstraint' has removed the
--     class constraint and substituted @mVar@ for the class's free parameter.
--   * The field uses 'noStrictness'.
--
-- Any non-signature declaration is reported as an internal error.
methodToField :: MonadFail m => Name -> Type -> Dec -> m VarStrictType
methodToField mVar classT dec = case dec of
  SigD name typ -> do
    fieldType <- replaceClassConstraint classT mVar typ
    return (methodNameToFieldName name, noStrictness, fieldType)
  _ -> fail "methodToField: internal error; report a bug with the test-fixture package"
-- | Record-field name for a class method: the method's base name with a
-- one-character prefix.
--
--   * The prefix is @~@ when the name is an operator, i.e. its first
--     character is punctuation or a symbol other than @_@, @:@, @\"@ or @'@.
--   * Otherwise the prefix is @_@.
--
-- The first character is inspected with a total pattern match, so an empty
-- base name gets @_@ instead of crashing.
methodNameToFieldName :: Name -> Name
methodNameToFieldName name = mkName (prefixChar : base)
  where
    base = nameBase name
    prefixChar = case base of
      c : _ | isInfixChar c -> '~'
      _ -> '_'
    isInfixChar c = (c `notElem` "_:\"'") && (isPunctuation c || isSymbol c)
-- | Rewrite a reified class-method type so it can be used as a record field.
--
-- The method type is expected to be @forall vs. (C a1 ... an, ps) => t@,
-- where @C@ is the head of @classType@. The rewrite does three things:
--
--   * removes the single @C ...@ constraint;
--   * substitutes its type variables, in order, with the arguments already
--     applied in @classType@ followed by @VarT freeVar@;
--   * drops those variables from the @forall@.
--
-- It fails in @m@, with a descriptive message, if the type is not a
-- 'ForallT' or if it does not contain exactly one @C@ constraint.
replaceClassConstraint :: MonadFail m => Type -> Name -> Type -> m Type
replaceClassConstraint classType freeVar (ForallT vars preds typ) =
  case partition ((classHead ==) . predHead) preds of
    ([replacedPred], newPreds) ->
      let replacedVars = typeVarNames replacedPred
          replacementTypes = typeArgs classType ++ [VarT freeVar]
          newVars = filter ((`notElem` replacedVars) . tyVarBndrName) vars
          substitutions = zip replacedVars replacementTypes
          replacedT = foldl' (flip $ uncurry substituteTypeVar) typ substitutions
      in return $ ForallT newVars newPreds replacedT
    (matches, _) ->
      fail $ "replaceClassConstraint: expected exactly one "
        ++ show (ppr classHead) ++ " constraint in method type, found "
        ++ show (length matches) ++ "; report a bug with the test-fixture package"
  where
    classHead = unappliedType classType
    -- Head of a constraint with all type applications stripped. This is
    -- total, so a constraint not headed by a class constructor (e.g. an
    -- equality constraint) simply does not match instead of calling 'error'.
    predHead (AppT t _) = predHead t
    predHead t = t
replaceClassConstraint _ _ _ = fail "replaceClassConstraint: internal error; report a bug with the test-fixture package"
-- | @substituteTypeVar v r t@ replaces every 'VarT' @v@ inside @t@ with @r@.
--
-- It descends into 'ForallT' bodies, both sides of 'AppT', and the type
-- (not the kind) of 'SigT'. 'ForallT' binders and contexts, and every other
-- constructor, are returned unchanged.
substituteTypeVar :: Name -> Type -> Type -> Type
substituteTypeVar initial replacement = go
  where
    go ty = case ty of
      VarT n | n == initial -> replacement
      ForallT bndrs ctx body -> ForallT bndrs ctx (go body)
      AppT f x -> AppT (go f) (go x)
      SigT inner k -> SigT (go inner) k
      _ -> ty
-- | Record-field binding @fieldName = unimplemented "<fieldName>"@, where the
-- string literal is the field's base name.
unimplementedField :: Name -> FieldExp
unimplementedField fieldName =
  (fieldName, VarE 'unimplemented `AppE` LitE (StringL (nameBase fieldName)))
-- | Implement one class method for the fixture instance. For a signature
-- @name :: a1 -> ... -> an -> r@ (arity counted by 'functionTypeArity') this
-- produces roughly
--
-- > name x1 ... xn = do
-- >   fn <- Reader.asks <field for name>
-- >   fn x1 ... xn
--
-- where the field name comes from 'methodNameToFieldName'. The generated code
-- assumes the reader environment is the fixture record; confirm this against
-- 'TestFixtureT'. Any declaration other than a 'SigD' is rejected.
mkDictInstanceFunc :: Dec -> Q Dec
mkDictInstanceFunc (SigD name typ) = do
  let arity = functionTypeArity typ
  -- One fresh argument name per arrow in the method type.
  argNames <- replicateM arity (newName "x")
  let pats = map VarP argNames
  -- Selector of the record field that holds this method's implementation.
  let askFunc = VarE (methodNameToFieldName name)
  let vars = map VarE argNames
  -- Fetch the field with Reader.asks, then apply it to all the arguments.
  implE <- [e|do
    fn <- Reader.asks $(return askFunc)
    $(return $ applyE (VarE 'fn) vars)|]
  let funClause = Clause pats (NormalB implE) []
  return $ FunD name [funClause]
mkDictInstanceFunc other = fail $ "mkDictInstanceFunc: expected method signature, given " ++ show other
-- | Strip all type applications from a constructor-headed type and return
-- the head 'ConT' (e.g. @MonadState Int@ becomes @MonadState@). Calls
-- 'error' if the head is anything other than a 'ConT'.
unappliedType :: Type -> Type
unappliedType ty = case ty of
  ConT{} -> ty
  AppT f _ -> unappliedType f
  _ -> error $ "expected plain applied type, given " ++ show ty
-- | Name of the head type constructor of a constructor-headed type. Fails
-- under the same conditions as 'unappliedType'.
unappliedTypeName :: Type -> Name
unappliedTypeName t = headName
  where
    ConT headName = unappliedType t
-- | Arguments applied to a type, from left to right:
-- @typeArgs (T a b) == [a, b]@. A type that is not an application has none.
typeArgs :: Type -> [Type]
typeArgs = collect []
  where
    -- Walk down the left spine of the application, prepending each argument
    -- so they come out in source order.
    collect acc (AppT f x) = collect (x : acc) f
    collect acc _ = acc
-- | Distinct type-variable names in a type made of 'VarT' and 'AppT', in
-- order of first occurrence. Other constructors contribute no names.
typeVarNames :: Type -> [Name]
typeVarNames = nub . occurrences
  where
    occurrences (VarT n) = [n]
    occurrences (AppT f x) = occurrences f ++ occurrences x
    occurrences _ = []
-- | The variable bound by a type-variable binder, whether or not it has a
-- kind annotation.
tyVarBndrName :: TyVarBndr -> Name
tyVarBndrName bndr = case bndr of
  PlainTV n -> n
  KindedTV n _ -> n
-- | Number of arguments in a function type, i.e. the count of top-level
-- @->@ arrows along the result spine. It looks through 'ForallT' quantifiers,
-- including ones nested in a result type. Any other type has arity 0.
functionTypeArity :: Type -> Int
functionTypeArity ty = case ty of
  ForallT _ _ body -> functionTypeArity body
  AppT (AppT ArrowT _) result -> succ (functionTypeArity result)
  _ -> 0
-- | Apply a function expression to arguments from left to right:
-- @applyE f [x, y] == AppE (AppE f x) y@.
applyE :: Exp -> [Exp] -> Exp
applyE fn args = foldl' AppE fn args
-- | Constraint used by the fallible helpers in this module ('classMethods',
-- 'methodToField', 'replaceClassConstraint', ...).
--
--   * On base >= 4.9 it is 'Fail.MonadFail'.
--   * On older base, which has no "Control.Monad.Fail", it falls back to
--     plain 'Monad'.
#if MIN_VERSION_base(4,9,0)
type MonadFail = Fail.MonadFail
#else
type MonadFail = Monad
#endif
-- | Build an instance declaration from its context, head and body. This
-- papers over template-haskell 2.11, which added a leading overlap-mode
-- field; here that field is always 'Nothing' (no overlap pragma).
mkInstanceD :: Cxt -> Type -> [Dec] -> Dec
#if MIN_VERSION_template_haskell(2,11,0)
mkInstanceD ctx instHead decs = InstanceD Nothing ctx instHead decs
#else
mkInstanceD ctx instHead decs = InstanceD ctx instHead decs
#endif
-- | Build a @data@ declaration with no deriving clauses. This papers over
-- template-haskell 2.11, which added an optional kind signature; here that
-- signature is always 'Nothing'.
mkDataD :: Cxt -> Name -> [TyVarBndr] -> [Con] -> Dec
#if MIN_VERSION_template_haskell(2,11,0)
mkDataD ctx tyName binders cons = DataD ctx tyName binders Nothing cons []
#else
mkDataD ctx tyName binders cons = DataD ctx tyName binders cons []
#endif
-- | Strictness annotation attached to every generated fixture record field:
-- no unpack pragma and no strictness bang. template-haskell 2.11 replaced the
-- 'Strict' type (value 'NotStrict') with 'Bang'; both versions are handled.
#if MIN_VERSION_template_haskell(2,11,0)
noStrictness :: Bang
noStrictness = Bang NoSourceUnpackedness NoSourceStrictness
#else
noStrictness :: Strict
noStrictness = NotStrict
#endif