module Test.Fixie.Internal.TH where
import qualified Control.Monad.Fail as Fail
import Control.Monad (join, replicateM, when, zipWithM)
import Test.Fixie.Internal (FixieT, Call(..), Function(..), unimplemented, captureCall, getFunction)
import Data.Char (isPunctuation, isSymbol)
import Data.Default.Class (Default(..))
import Data.List (foldl', nub, partition)
import Data.Text (pack)
import GHC.Exts (Constraint)
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
-- | Entry point of the fixture generator.
--
-- Given the name for the fixture record and the (partially applied) class
-- constraints to fake, this:
--
--   1. checks every constraint with 'assertDerivableConstraint',
--   2. generates the fixture record type via 'mkFixtureRecord',
--   3. generates a 'Default' instance for that record, and
--   4. generates one instance per class via 'mkInstance'.
--
-- The declarations are returned in that order: record, 'Default' instance,
-- then the class instances in the order the classes were given.
mkFixture :: String -> [Type] -> Q [Dec]
mkFixture fixtureNameStr classTypes = do
  mapM_ assertDerivableConstraint classTypes
  (recordDec, recordFields) <- mkFixtureRecord recordName classTypes
  defaultDec <- mkDefaultInstance recordName recordFields
  classInstances <- mapM (\classType -> mkInstance classType recordName) classTypes
  pure (recordDec : defaultDec : classInstances)
  where
    recordName = mkName fixtureNameStr
-- | Build the fixture record declaration
-- @data Fixture (m :: * -> *) = Fixture { ... }@ with one field per method
-- of every given class.
--
-- Each class is reified, its method signatures are turned into fields by
-- 'methodsToFields', and the fields of all classes are concatenated in class
-- order. Returns the declaration together with the generated fields.
mkFixtureRecord :: Name -> [Type] -> Q (Dec, [VarStrictType])
mkFixtureRecord fixtureName classTypes = do
  classInfos <- mapM (reify . unappliedTypeName) classTypes
  methodDecs <- mapM classMethods classInfos
  monadVar <- newName "m"
  fieldsPerClass <- zipWithM (methodsToFields monadVar) classTypes methodDecs
  let fields = concat fieldsPerClass
      -- Kind of the record's type parameter: * -> *
      starToStar = AppT (AppT ArrowT StarT) StarT
      recordCon = RecC fixtureName fields
      recordDec = mkDataD [] fixtureName [KindedTV monadVar starToStar] [recordCon]
  pure (recordDec, fields)
-- | Build @instance Default (Fixture m)@, where 'def' sets every field of
-- the fixture record to the expression produced by 'unimplementedField' for
-- that field's name.
mkDefaultInstance :: Name -> [VarStrictType] -> Q Dec
mkDefaultInstance fixtureName fixtureFields = do
  m <- newName "m"
  let fixtureOfM = ConT fixtureName `AppT` VarT m
      defaultHead = ConT ''Default `AppT` fixtureOfM
      stubbedFields = [unimplementedField field | (field, _, _) <- fixtureFields]
      defBody = NormalB (RecConE fixtureName stubbedFields)
      defMethod = FunD 'def [Clause [] defBody []]
  pure (mkInstanceD [] defaultHead [defMethod])
-- | Build @instance Monad m => C (FixieT Fixture e m)@ for the given
-- (partially applied) class @C@. Each method body is generated by
-- 'mkDictInstanceFunc'.
--
-- Fails in 'Q' if the reified name is not a class.
mkInstance :: Type -> Name -> Q Dec
mkInstance classType fixtureName = do
  e <- newName "e"
  m <- newName "m"
  let fixieT = ConT ''FixieT `AppT` ConT fixtureName `AppT` VarT e `AppT` VarT m
      instanceHead = classType `AppT` fixieT
      context = [ConT ''Monad `AppT` VarT m]
  classInfo <- reify (unappliedTypeName classType)
  methodDecs <- case classInfo of
    ClassI (ClassD _ _ _ _ decs) _ -> pure decs
    _ -> fail $ "mkInstance: expected a class type, given " ++ show classType
  instanceMethods <- mapM mkDictInstanceFunc methodDecs
  pure (mkInstanceD context instanceHead instanceMethods)
-- | Check that a constraint passed to 'mkFixture' is one a fixture instance
-- can be derived for.
--
-- The constraint must name a type class which, after applying the supplied
-- arguments, has exactly one parameter left unapplied. Otherwise this fails
-- in 'Q' with a descriptive message:
--
--   * the name does not refer to a class;
--   * more arguments are given than the class has parameters;
--   * the constraint is fully saturated;
--   * two or more parameters are left unapplied.
assertDerivableConstraint :: Type -> Q ()
assertDerivableConstraint classType = do
  info <- reify $ unappliedTypeName classType
  -- Match the whole ClassI/ClassD shape at once. Any non-class 'Info' then
  -- gets the descriptive message below instead of a bare pattern-match
  -- failure.
  classVars <- case info of
    ClassI (ClassD _ _ vars _ _) _ -> return vars
    _ -> fail $ "mkFixture: expected a constraint, given ‘" ++ show (ppr classType) ++ "’"

  let classArgs = typeArgs classType
      argCount = length classArgs
      varCount = length classVars
      -- Kind of the class abstracted over @vars@, in declaration order:
      -- @k1 -> k2 -> ... -> Constraint@. Only used in error messages.
      mkClassKind vars =
        foldr (\a b -> AppT (AppT ArrowT a) b) (ConT ''Constraint) (map varKind vars)
      -- Binders without an explicit kind are displayed as @*@ rather than
      -- crashing on a non-exhaustive pattern.
      varKind (KindedTV _ k) = k
      varKind (PlainTV _) = StarT
      constraintStr = show (ppr (ConT ''Constraint))

  when (argCount > varCount) $
    fail $ "mkFixture: too many arguments for class\n"
      ++ " in: " ++ show (ppr classType) ++ "\n"
      ++ " for class of kind: " ++ show (ppr (mkClassKind classVars))

  when (argCount == varCount) $
    fail $ "mkFixture: cannot derive instance for fully saturated constraint\n"
      ++ " in: " ++ show (ppr classType) ++ "\n"
      ++ " expected: * -> " ++ constraintStr ++ "\n"
      ++ " given: " ++ constraintStr

  -- Two or more parameters left unapplied: only the last one (the monad)
  -- can be filled in by the generated instance.
  when (argCount < varCount - 1) $
    fail $ "mkFixture: cannot derive instance for multi-parameter typeclass\n"
      ++ " in: " ++ show (ppr classType) ++ "\n"
      ++ " expected: * -> " ++ constraintStr ++ "\n"
      ++ " given: " ++ show (ppr (mkClassKind $ drop argCount classVars))
-- | Extract the member declarations of a reified class.
-- Fails (via 'fail') for any 'Info' that is not a class.
classMethods :: MonadFail m => Info -> m [Dec]
classMethods info = case info of
  ClassI (ClassD _ _ _ _ decs) _ -> pure decs
  other -> fail ("classMethods: expected a class name, given " ++ show other)
-- | Convert every method signature of one class into a fixture record field,
-- using 'methodToField' with the fixture's monad variable and the class type.
methodsToFields :: MonadFail m => Name -> Type -> [Dec] -> m [VarStrictType]
methodsToFields monadVar classType decs = traverse (methodToField monadVar classType) decs
-- | Turn one class method signature into a lazy fixture record field.
--
-- The field name comes from 'methodNameToFieldName'. The field type is the
-- method type with the class constraint replaced by
-- 'replaceClassConstraint'. Any declaration other than a 'SigD' fails.
methodToField :: MonadFail m => Name -> Type -> Dec -> m VarStrictType
methodToField mVar classT (SigD methodName methodType) = do
  fieldType <- replaceClassConstraint classT mVar methodType
  return (methodNameToFieldName methodName, noStrictness, fieldType)
methodToField _ _ _ = fail "methodToField: internal error; report a bug with the test-fixture package"
-- | Name of the fixture record field that stubs a class method.
--
-- The field name is the method's base name with a prefix:
--
--   * @~@ when the method is an operator, i.e. its first character is
--     punctuation or a symbol, other than @_@, @:@, @\"@ or @'@;
--   * @_@ otherwise.
methodNameToFieldName :: Name -> Name
methodNameToFieldName name = mkName (prefixChar : base)
  where
    base = nameBase name
    isInfixChar c = (c `notElem` "_:\"'") && (isPunctuation c || isSymbol c)
    -- Pattern match instead of the partial 'head'. An (unexpected) empty
    -- base name gets the plain @_@ prefix instead of crashing.
    nameIsInfix = case base of
      c : _ -> isInfixChar c
      [] -> False
    prefixChar = if nameIsInfix then '~' else '_'
-- | Rewrite a class method's type into the type of its fixture record field.
--
-- Arguments:
--
--   * the constraint being faked, e.g. @MonadState Int@;
--   * the fixture record's monad type variable;
--   * the method's type, which must be a 'ForallT'.
--
-- Steps:
--
--   1. The single constraint headed by that class is removed from the
--      context.
--   2. The type variables of that constraint, in order of first appearance,
--      are substituted with the class's applied arguments followed by the
--      monad variable.
--   3. Binders for the substituted variables are dropped from the forall.
--
-- Fails if the type is not a 'ForallT', or if the context does not contain
-- exactly one constraint headed by the class.
replaceClassConstraint :: MonadFail m => Type -> Name -> Type -> m Type
replaceClassConstraint classType freeVar (ForallT vars preds typ) =
  case partition isClassPred preds of
    ([replacedPred], newPreds) ->
      let replacedVars = typeVarNames replacedPred
          replacementTypes = typeArgs classType ++ [VarT freeVar]
          newVars = filter ((`notElem` replacedVars) . tyVarBndrName) vars
          replacedT = foldl' (flip $ uncurry substituteTypeVar) typ (zip replacedVars replacementTypes)
      in return $ ForallT newVars newPreds replacedT
    -- Zero or several matching constraints: report through 'fail' instead
    -- of crashing later on a failed irrefutable pattern.
    (classPreds, _) ->
      fail $ "replaceClassConstraint: expected exactly one ‘" ++ show (ppr classType)
        ++ "’ constraint, found " ++ show (length classPreds)
  where
    unappliedClassType = unappliedType classType
    -- Total head extraction. Unlike 'unappliedType', this does not error on
    -- constraints whose head is not a type constructor (e.g. equality
    -- constraints); such heads simply never match the class.
    predHead (AppT t _) = predHead t
    predHead t = t
    isClassPred = (unappliedClassType ==) . predHead
replaceClassConstraint _ _ _ = fail "replaceClassConstraint: internal error; report a bug with the test-fixture package"
-- | @substituteTypeVar v r t@ replaces every occurrence of the type variable
-- @v@ in @t@ with @r@.
--
-- The traversal descends through:
--
--   * type applications ('AppT');
--   * kind signatures ('SigT'), substituting in the type but not the kind;
--   * 'ForallT', in both the body and the context.
--
-- All other type forms are returned unchanged.
substituteTypeVar :: Name -> Type -> Type -> Type
substituteTypeVar initial replacement = doReplace
  where
    -- The context is rewritten too. Otherwise a nested constraint could
    -- still mention a variable that has been replaced everywhere else.
    doReplace (ForallT a b t) = ForallT a (map doReplace b) (doReplace t)
    doReplace (AppT a b) = AppT (doReplace a) (doReplace b)
    doReplace (SigT t k) = SigT (doReplace t) k
    doReplace t@(VarT n)
      | n == initial = replacement
      | otherwise = t
    doReplace other = other
-- | Record-construction entry that sets the given field to
-- @unimplemented "<field base name>"@.
unimplementedField :: Name -> FieldExp
unimplementedField fieldName =
  let message = LitE (StringL (nameBase fieldName))
  in (fieldName, VarE 'unimplemented `AppE` message)
-- | Generate the instance-method definition for one class method signature.
--
-- For @SigD name typ@ this produces a single-clause function
--
-- > name x1 ... xn = do
-- >   fn <- getFunction <field accessor>
-- >   let fnString = "<name>"
-- >   let call = Call $ Function (pack fnString)
-- >   captureCall call
-- >   fn x1 ... xn
--
-- Here @n@ is 'functionTypeArity' of @typ@ and the field accessor is the
-- record field named by 'methodNameToFieldName'. What 'getFunction' and
-- 'captureCall' do is defined in "Test.Fixie.Internal". Presumably they look
-- up the stub in the fixture and record the call; confirm there.
--
-- Any declaration other than a 'SigD' fails.
mkDictInstanceFunc :: Dec -> Q Dec
mkDictInstanceFunc (SigD name typ) = do
  -- Number of top-level arrows in the method type; one pattern per argument.
  let arity = functionTypeArity typ
  argNames <- replicateM arity (newName "x")
  let pats = map VarP argNames
  -- Accessor for the fixture record field that stubs this method.
  let askFunc = VarE (methodNameToFieldName name)
  -- The method's unqualified name, recorded with each captured call.
  let nameString = LitE (StringL (nameBase name))
  let vars = map VarE argNames
  -- The final splice applies the quote-bound 'fn' (via the name quote 'fn)
  -- to the fresh argument variables.
  implE <- [e|do
    fn <- getFunction $(return askFunc)
    let fnString = $(return nameString)
    let call = Call $ Function (pack fnString)
    captureCall call
    $(return $ applyE (VarE 'fn) vars)
    |]
  let funClause = Clause pats (NormalB implE) []
  return $ FunD name [funClause]
mkDictInstanceFunc other = fail $ "mkDictInstanceFunc: expected method signature, given " ++ show other
-- | Strip all type applications and return the head type constructor, e.g.
-- @MonadState Int@ becomes @MonadState@.
--
-- Calls 'error' if the head of the application is not a 'ConT'.
unappliedType :: Type -> Type
unappliedType t = case t of
  ConT _ -> t
  AppT f _ -> unappliedType f
  _ -> error $ "expected plain applied type, given " ++ show t
-- | Name of the head type constructor of an applied type, as found by
-- 'unappliedType'. Errors whenever 'unappliedType' does.
unappliedTypeName :: Type -> Name
unappliedTypeName t = headName
  where
    ConT headName = unappliedType t
-- | Arguments of a type application, outermost-last, e.g.
-- @C a b@ gives @[a, b]@. Any type that is not an 'AppT' has no arguments.
typeArgs :: Type -> [Type]
typeArgs = collect []
  where
    -- Walk down the left spine, consing each argument onto the front so
    -- the result ends up in source order.
    collect acc (AppT f x) = collect (x : acc) f
    collect acc _ = acc
-- | Distinct type variable names appearing in a type built from 'VarT' and
-- 'AppT', in order of first (left-to-right) appearance. Other type forms
-- contribute no names and are not descended into.
typeVarNames :: Type -> [Name]
typeVarNames = nub . collect
  where
    collect (VarT n) = [n]
    collect (AppT f x) = collect f ++ collect x
    collect _ = []
-- | Name bound by a type variable binder, whether or not it is kinded.
tyVarBndrName :: TyVarBndr -> Name
tyVarBndrName bndr = case bndr of
  PlainTV bound -> bound
  KindedTV bound _ -> bound
-- | Number of arguments of a function type: the count of top-level arrows,
-- looking through any outer 'ForallT'. Non-function types have arity 0.
functionTypeArity :: Type -> Int
functionTypeArity ty = case ty of
  AppT (AppT ArrowT _) result -> 1 + functionTypeArity result
  ForallT _ _ body -> functionTypeArity body
  _ -> 0
-- | Apply an expression to a list of arguments, left to right:
-- @applyE f [a, b]@ is @f a b@. With no arguments, returns @f@ unchanged.
applyE :: Exp -> [Exp] -> Exp
applyE fn [] = fn
applyE fn (arg : rest) = applyE (AppE fn arg) rest
-- | Local alias for 'Fail.MonadFail', used in the constraints of the
-- 'fail'-using helpers in this module.
-- NOTE(review): on base >= 4.13 the Prelude also exports 'MonadFail', so
-- unqualified uses of this name may be ambiguous unless the Prelude import is
-- restricted or this definition is CPP-guarded somewhere not visible here —
-- confirm.
type MonadFail = Fail.MonadFail
-- | Build an instance declaration with no overlap pragma
-- ('InstanceD' with 'Nothing' as its first field).
mkInstanceD :: Cxt -> Type -> [Dec] -> Dec
mkInstanceD context instanceHead body = InstanceD Nothing context instanceHead body
-- | Build a @data@ declaration with no explicit result kind and no deriving
-- clauses.
mkDataD :: Cxt -> Name -> [TyVarBndr] -> [Con] -> Dec
mkDataD context typeName binders constructors =
  DataD context typeName binders Nothing constructors []
-- | Field annotation for a plain lazy field: no UNPACK/NOUNPACK pragma and
-- no strictness (@!@) or laziness (@~@) mark.
noStrictness :: Bang
noStrictness =
  Bang
    NoSourceUnpackedness
    NoSourceStrictness