module Control.Monad.TestFixture.TH.Internal where
#if MIN_VERSION_base(4,9,0)
import qualified Control.Monad.Fail as Fail
#endif
import qualified Control.Monad.Reader as Reader
import Prelude hiding (log)
import Control.Monad (join, replicateM, zipWithM)
import Control.Monad.TestFixture (TestFixture, TestFixtureT, unimplemented)
import Data.Char (isPunctuation, isSymbol)
import Data.Default (Default(..))
import Data.List (foldl', nub, partition)
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
-- | Generate a complete test fixture for the given typeclasses.
--
-- @mkFixture "Fixture" [''MonadA, ''MonadB]@ produces, in order:
--
-- * the fixture record declaration (see 'mkFixtureRecord'),
-- * a 'Default' instance for it (see 'mkDefaultInstance'),
-- * the @Fixture*@ type synonyms (see 'mkFixtureTypeSynonyms'), and
-- * one instance per class for 'TestFixtureT' (see 'mkInstance').
mkFixture :: String -> [Name] -> Q [Dec]
mkFixture fixtureNameStr classNames = do
  let fixtureName = mkName fixtureNameStr
  (recordDec, recordFields) <- mkFixtureRecord fixtureName classNames
  synonymDecs <- mkFixtureTypeSynonyms fixtureName
  defaultDec <- mkDefaultInstance fixtureName recordFields
  classInfos <- mapM reify classNames
  classInstances <- mapM (\info -> mkInstance info fixtureName) classInfos
  pure (recordDec : defaultDec : synonymDecs ++ classInstances)
-- | Build the fixture record declaration.
--
-- The record and its single constructor are both named @fixtureName@. The
-- record is parameterised over a fresh type variable @m@, and it has one
-- field per method of every listed class (see 'methodsToFields'). The field
-- list is returned alongside the declaration so callers can reuse it.
mkFixtureRecord :: Name -> [Name] -> Q (Dec, [VarStrictType])
mkFixtureRecord fixtureName classNames = do
  classTypes <- mapM conT classNames
  classInfos <- mapM reify classNames
  methodDecs <- mapM classMethods classInfos
  monadName <- newName "m"
  fields <- concat <$> zipWithM (methodsToFields monadName) classTypes methodDecs
  let constructor = RecC fixtureName fields
      recordDec = mkDataD [] fixtureName [PlainTV monadName] [constructor]
  pure (recordDec, fields)
-- | Generate the convenience type synonyms for a fixture record.
--
-- For a fixture named @Fixture@ this yields @FixturePure@, @FixtureLog@,
-- @FixtureState@ and @FixtureLogState@, which apply the fixture to
-- @TestFixture Fixture log state@. It also yields @FixturePureT@,
-- @FixtureLogT@, @FixtureStateT@ and @FixtureLogStateT@, which apply it to
-- @TestFixtureT Fixture log state m@. Any log/state slot a synonym does not
-- expose as a parameter is fixed to @()@.
mkFixtureTypeSynonyms :: Name -> Q [Dec]
mkFixtureTypeSynonyms fixtureName = do
  mName <- newName "m"
  logName <- newName "log"
  stateName <- newName "state"
  let mT = VarT mName
      logT = VarT logName
      stateT = VarT stateName
      mB = PlainTV mName
      logB = PlainTV logName
      stateB = PlainTV stateName
      -- (name suffix, synonym parameters, right-hand side)
      synonyms =
        [ ("Pure", [], pureFixture unitT unitT)
        , ("Log", [logB], pureFixture logT unitT)
        , ("State", [stateB], pureFixture unitT stateT)
        , ("LogState", [logB, stateB], pureFixture logT stateT)
        , ("PureT", [mB], transFixture unitT unitT mT)
        , ("LogT", [logB, mB], transFixture logT unitT mT)
        , ("StateT", [stateB, mB], transFixture unitT stateT mT)
        , ("LogStateT", [logB, stateB, mB], transFixture logT stateT mT)
        ]
  pure
    [ TySynD (mkName (nameBase fixtureName ++ suffix)) binders rhs
    | (suffix, binders, rhs) <- synonyms
    ]
  where
    unitT = TupleT 0
    fixtureT = ConT fixtureName
    -- Fixture (TestFixture Fixture l s)
    pureFixture l s =
      AppT fixtureT (foldl' AppT (ConT ''TestFixture) [fixtureT, l, s])
    -- Fixture (TestFixtureT Fixture l s m)
    transFixture l s m =
      AppT fixtureT (foldl' AppT (ConT ''TestFixtureT) [fixtureT, l, s, m])
-- | Generate @instance Default (Fixture m)@.
--
-- Every record field is initialised by 'unimplementedField', so @def@ is a
-- fixture whose fields are all 'unimplemented' calls labelled with the field
-- name.
mkDefaultInstance :: Name -> [VarStrictType] -> Q Dec
mkDefaultInstance fixtureName fixtureFields = do
  mName <- newName "m"
  let instanceHead = ConT ''Default `AppT` (ConT fixtureName `AppT` VarT mName)
      fieldInits = [unimplementedField fieldName | (fieldName, _, _) <- fixtureFields]
      defBody = NormalB (RecConE fixtureName fieldInits)
  pure (mkInstanceD [] instanceHead [FunD 'def [Clause [] defBody []]])
-- | Generate @instance Monad m => Class (TestFixtureT Fixture w s m)@.
--
-- Each class method is implemented by 'mkDictInstanceFunc'. Reified info
-- that is not a class declaration is rejected with 'fail'.
mkInstance :: Info -> Name -> Q Dec
mkInstance info fixtureName = case info of
  ClassI (ClassD _ className _ _ methods) _ -> do
    w <- VarT <$> newName "w"
    s <- VarT <$> newName "s"
    m <- VarT <$> newName "m"
    let fixtureT = foldl' AppT (ConT ''TestFixtureT) [ConT fixtureName, w, s, m]
        instanceHead = AppT (ConT className) fixtureT
    decls <- mapM mkDictInstanceFunc methods
    pure (mkInstanceD [AppT (ConT ''Monad) m] instanceHead decls)
  other -> fail $ "mkInstance: expected a class name, given " ++ show other
-- | Extract the member declarations of a reified class.
--
-- Fails (via 'fail') when the 'Info' does not describe a class declaration.
classMethods :: MonadFail m => Info -> m [Dec]
classMethods info = case info of
  ClassI (ClassD _ _ _ _ decs) _ -> pure decs
  _ -> fail ("classMethods: expected a class name, given " ++ show info)
-- | Convert every method signature of one class into a fixture record field.
--
-- @mName@ is the record's monad type variable; @classType@ is the class being
-- turned into fields (see 'methodToField').
methodsToFields :: MonadFail m => Name -> Type -> [Dec] -> m [VarStrictType]
methodsToFields mName classType decs = traverse (methodToField mName classType) decs
-- | Turn one class-method signature into a lazy record field.
--
-- The field name comes from 'methodNameToFieldName'. The field type is the
-- method type with the class constraint removed and the class variable
-- replaced by @mVar@ (see 'replaceClassConstraint'). Any declaration other
-- than a type signature is reported with 'fail'.
methodToField :: MonadFail m => Name -> Type -> Dec -> m VarStrictType
methodToField mVar classT dec = case dec of
  SigD name typ -> do
    fieldType <- replaceClassConstraint classT mVar typ
    pure (methodNameToFieldName name, noStrictness, fieldType)
  _ -> fail "methodToField: internal error; report a bug with the test-fixture package"
-- | Derive the fixture record field name for a class method.
--
-- Ordinary identifiers get a @_@ prefix (@fetch@ becomes @_fetch@). Operator
-- names get a @~@ prefix so the result is still a valid operator (@+++@
-- becomes @~+++@).
--
-- Matches on the first character directly instead of calling the partial
-- 'head'. An empty base name, which reification never produces, falls back
-- to the @_@ prefix.
methodNameToFieldName :: Name -> Name
methodNameToFieldName name = mkName (prefixChar : base)
  where
    base = nameBase name
    -- Symbol/punctuation characters start an operator, except for characters
    -- that can also appear in (or delimit) ordinary identifiers.
    isInfixChar c = (c `notElem` "_:\"'") && (isPunctuation c || isSymbol c)
    prefixChar = case base of
      (c : _) | isInfixChar c -> '~'
      _ -> '_'
-- | Strip the class constraint from a reified method type.
--
-- A reified method type has the form
-- @forall vars. (Class v, ...) => body@. This removes the predicate whose
-- head is @constraint@ and drops the class variable @v@ from the binders. It
-- then renames @v@ to @freeVar@ throughout @body@.
--
-- Fails (via 'fail') when:
--
-- * the class constraint does not occur exactly once among the predicates
--   (previously this crashed with an opaque irrefutable-pattern error),
-- * the class constraint mentions more than one type variable
--   (multi-parameter classes are unsupported), or
-- * the type is not a 'ForallT'.
replaceClassConstraint :: MonadFail m => Type -> Name -> Type -> m Type
replaceClassConstraint constraint freeVar (ForallT vars preds typ) = do
  let (newPreds, matchingPreds) = partition ((constraint /=) . unappliedType) preds
  replacedPred <- case matchingPreds of
    [single] -> return single
    _ -> fail $ "replaceClassConstraint: expected exactly one occurrence of the class constraint "
                ++ show constraint ++ ", found " ++ show (length matchingPreds)
  replacedVar <- case typeVarNames replacedPred of
    [singleVar] -> return singleVar
    _ -> fail "generating instances of multi-parameter typeclasses is currently unsupported"
  let newVars = filter ((replacedVar /=) . tyVarBndrName) vars
      replacedT = replaceTypeVarName replacedVar freeVar typ
  return $ ForallT newVars newPreds replacedT
replaceClassConstraint _ _ _ = fail "replaceClassConstraint: internal error; report a bug with the test-fixture package"
-- | Rename every occurrence of the type variable @initial@ to @replacement@.
--
-- The traversal descends through 'ForallT' bodies and their constraints,
-- type applications and kind signatures (the kind itself is left untouched).
-- Constraints of nested foralls are rewritten too. Otherwise a method
-- constraint mentioning the class variable, such as @MonadIO m => m ()@,
-- would keep the old, now-unbound variable.
replaceTypeVarName :: Name -> Name -> Type -> Type
replaceTypeVarName initial replacement = go
  where
    go (ForallT binders ctx t) = ForallT binders (map go ctx) (go t)
    go (AppT f x) = AppT (go f) (go x)
    go (SigT t k) = SigT (go t) k
    go (VarT n)
      | n == initial = VarT replacement
      | otherwise = VarT n
    go other = other
-- | Record-field initialiser @field = unimplemented "field"@.
--
-- Used to build the 'Default' fixture, where every field is a placeholder
-- labelled with its own name.
unimplementedField :: Name -> FieldExp
unimplementedField fieldName =
  let label = LitE (StringL (nameBase fieldName))
  in (fieldName, VarE 'unimplemented `AppE` label)
-- | Generate the instance method for one class-method signature.
--
-- The generated method takes one fresh argument per top-level arrow in the
-- method type (see 'functionTypeArity'). It fetches the matching fixture
-- field (named by 'methodNameToFieldName') with 'Reader.asks' and applies
-- that field to the arguments. Declarations other than type signatures are
-- rejected with 'fail'.
mkDictInstanceFunc :: Dec -> Q Dec
mkDictInstanceFunc (SigD name typ) = do
  let arity = functionTypeArity typ
  argNames <- replicateM arity (newName "x")
  let pats = map VarP argNames
  -- Selector for the fixture field that backs this method.
  let askFunc = VarE (methodNameToFieldName name)
  let vars = map VarE argNames
  -- NOTE(review): 'fn names the variable bound by the quote's own do-block;
  -- the nested splice applies it to the generated argument variables.
  implE <- [e|do
    fn <- Reader.asks $(return askFunc)
    $(return $ applyE (VarE 'fn) vars)|]
  let funClause = Clause pats (NormalB implE) []
  return $ FunD name [funClause]
mkDictInstanceFunc other = fail $ "mkDictInstanceFunc: expected method signature, given " ++ show other
-- | Head of a type application: @unappliedType (C a b) == C@.
--
-- Only type constructors applied to arguments are supported. Any other
-- shape is a fatal 'error'.
unappliedType :: Type -> Type
unappliedType ty = case ty of
  ConT{} -> ty
  AppT f _ -> unappliedType f
  _ -> error $ "expected plain applied type, given " ++ show ty
-- | Distinct type-variable names appearing in a type, in first-occurrence
-- order.
--
-- Only variables and type applications are traversed. Every other
-- constructor contributes no names.
typeVarNames :: Type -> [Name]
typeVarNames = nub . collect
  where
    collect (VarT n) = [n]
    collect (AppT f x) = collect f ++ collect x
    collect _ = []
-- | Name bound by a type-variable binder, with or without a kind annotation.
tyVarBndrName :: TyVarBndr -> Name
tyVarBndrName bndr = case bndr of
  PlainTV n -> n
  KindedTV n _ -> n
-- | Number of top-level arrows in a function type.
--
-- Leading @forall@s are looked through (also after an arrow), so
-- @forall a. a -> forall b. b -> r@ has arity 2. Anything that is not an
-- arrow or a forall ends the count.
functionTypeArity :: Type -> Int
functionTypeArity = count 0
  where
    count n (AppT (AppT ArrowT _) rest) = count (n + 1) rest
    count n (ForallT _ _ inner) = count n inner
    count n _ = n
-- | Apply an expression to a list of arguments, left to right:
-- @applyE f [x, y] == AppE (AppE f x) y@.
applyE :: Exp -> [Exp] -> Exp
applyE fn [] = fn
applyE fn (arg : rest) = applyE (AppE fn arg) rest
-- | Constraint for the helpers in this module that report errors with 'fail'.
--
-- On base >= 4.9 it is 'Control.Monad.Fail.MonadFail'. On older base it
-- falls back to 'Monad', so 'fail' resolves to the 'Monad' method there.
#if MIN_VERSION_base(4,9,0)
type MonadFail = Fail.MonadFail
#else
type MonadFail = Monad
#endif
-- | Version-portable 'InstanceD' constructor.
--
-- On template-haskell >= 2.11, 'InstanceD' takes an extra leading argument.
-- It is always set to 'Nothing' here (no overlap pragma). Older versions take
-- the context, head and body directly.
mkInstanceD :: Cxt -> Type -> [Dec] -> Dec
#if MIN_VERSION_template_haskell(2,11,0)
mkInstanceD = InstanceD Nothing
#else
mkInstanceD = InstanceD
#endif
-- | Version-portable 'DataD' constructor: context, type name, type
-- parameters and constructors, with no deriving clauses.
--
-- On template-haskell >= 2.11 the extra optional-kind argument is set to
-- 'Nothing'.
mkDataD :: Cxt -> Name -> [TyVarBndr] -> [Con] -> Dec
#if MIN_VERSION_template_haskell(2,11,0)
mkDataD a b c d = DataD a b c Nothing d []
#else
mkDataD a b c d = DataD a b c d []
#endif
-- | Field annotation meaning "no strictness or unpack annotation", used for
-- every generated fixture record field.
--
-- template-haskell >= 2.11 expresses this as a 'Bang'. Older versions use
-- the 'Strict' value 'NotStrict'.
#if MIN_VERSION_template_haskell(2,11,0)
noStrictness :: Bang
noStrictness = Bang NoSourceUnpackedness NoSourceStrictness
#else
noStrictness :: Strict
noStrictness = NotStrict
#endif