{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedRecordDot #-}

-- | Template Haskell helpers that generate the per-method mock builder
-- functions (signature + body) and the instance method bodies that look a
-- registered mock up at runtime.
module Test.MockCat.TH.FunctionBuilder
  ( createMockBuilderVerifyParams
  , createMockBuilderFnType
  , MockFnContext(..)
  , MockFnBuilder(..)
  , buildMockFnContext
  , buildMockFnDeclarations
  , determineMockFnBuilder
  , createNoInlinePragma
  , doCreateMockFnDecs
  , doCreateConstantMockFnDecs
  , doCreateEmptyVerifyParamMockFnDecs
  , createMockBody
  , createTypeablePreds
  , partialAdditionalPredicates
  , createFnName
  , findParam
  , typeToNames
  , safeIndex
  , generateInstanceMockFnBody
  , generateInstanceRealFnBody
  , generateStubFn
  ) where

import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (lift)
import Language.Haskell.TH
  ( Dec (..)
  , Exp (..)
  , Name
  , Pred
  , Q
  , Quote
  , Type (..)
  , TyVarBndr(..)
  , Inline (NoInline)
  , RuleMatch (FunLike)
  , Phases (AllPhases)
  , mkName
  , newName
  )
import Language.Haskell.TH.Lib
import Language.Haskell.TH.Syntax (nameBase, Specificity (SpecifiedSpec))
import Test.MockCat.Mock ( MockBuilder )
import qualified Test.MockCat.Internal.MockRegistry as Registry
import Test.MockCat.Internal.Builder (buildMock)
import Test.MockCat.Internal.Types (BuiltMock(..))
import Test.MockCat.Cons (Head(..), (:>)(..))
import Test.MockCat.MockT
  ( MockT (..)
  , Definition (..)
  , getDefinitions
  , addDefinition
  )
import Test.MockCat.TH.TypeUtils
  ( isNotConstantFunctionType
  , needsTypeable
  , collectTypeVars
  , collectTypeableTargets
  )
import Test.MockCat.TH.ContextBuilder
  ( MockType (..)
  )
import Test.MockCat.TH.ClassAnalysis
  ( VarAppliedType (..)
  , updateType
  )
import Test.MockCat.Verify (ResolvableParamsOf)
import Data.Dynamic (Dynamic, toDyn)
import Data.Proxy (Proxy(..))
import Data.List (find, nubBy)
import Data.Typeable (Typeable)
import Language.Haskell.TH.Ppr (pprint)
import Unsafe.Coerce (unsafeCoerce)
import GHC.TypeLits (KnownSymbol, symbolVal)
import Test.MockCat.Param (Param, param)
import Test.MockCat.TH.Types (MockOptions(..))

-- | Build the verification-parameter type for a method type.
--
-- Each argument type @t@ becomes @Param t@ and the arguments are chained
-- with '(:>)'. An argument whose result is a type-variable application
-- (e.g. @m r@) terminates the chain with a bare @Param t@. Any non-function
-- type (after looking through @forall@) yields @()@.
--
-- NOTE(review): a function whose final result is not a type-variable
-- application (e.g. @a -> Int@) yields @Param a :> ()@; confirm intended.
createMockBuilderVerifyParams :: Type -> Type
createMockBuilderVerifyParams (AppT (AppT ArrowT ty) (AppT (VarT _) _)) =
  AppT (ConT ''Param) ty
createMockBuilderVerifyParams (AppT (AppT ArrowT ty) ty2) =
  AppT
    (AppT (ConT ''(:>)) (AppT (ConT ''Param) ty))
    (createMockBuilderVerifyParams ty2)
createMockBuilderVerifyParams (ForallT _ _ ty) = createMockBuilderVerifyParams ty
-- Variables, constructors, applications and anything else carry no params.
createMockBuilderVerifyParams _ = TupleT 0

-- | Remove the monad from a method's result type.
--
-- An application @m x@ where @m@ is @monadVarName@ is rewritten to @x@;
-- other applications are traversed through their argument (right-hand) side,
-- so for @a -> m b@ the result is @a -> b@. @forall@ is looked through and
-- dropped; every other type is returned unchanged.
createMockBuilderFnType :: Name -> Type -> Type
createMockBuilderFnType monadVarName a@(AppT (VarT var) ty)
  | monadVarName == var = ty
  | otherwise = a
createMockBuilderFnType monadVarName (AppT ty ty2) =
  AppT ty (createMockBuilderFnType monadVarName ty2)
createMockBuilderFnType monadVarName (ForallT _ _ ty) =
  createMockBuilderFnType monadVarName ty
createMockBuilderFnType _ ty = ty

-- | The @ResolvableParamsOf funType ~ verifyParams@ equality constraint,
-- emitted only when @funType@ mentions type variables (per
-- 'collectTypeVars'); otherwise no extra constraint is needed.
partialAdditionalPredicates :: Type -> Type -> [Pred]
partialAdditionalPredicates funType verifyParams =
  [ AppT (AppT EqualityT (AppT (ConT ''ResolvableParamsOf) funType)) verifyParams
  | not (null (collectTypeVars funType))
  ]

-- | 'Typeable' constraints for every target collected (via
-- 'collectTypeableTargets') from the given types that 'needsTypeable'.
-- Duplicates are removed by comparing pretty-printed forms.
createTypeablePreds :: [Type] -> [Pred]
createTypeablePreds targets =
  [ AppT (ConT ''Typeable) t
  | t <- nubBy (\a b -> pprint a == pprint b) (concatMap collectTypeableTargets targets)
  , needsTypeable t
  ]

-- | Everything needed to generate the mock builder for one class method.
data MockFnContext = MockFnContext
  { mockType :: MockType
  , monadVarName :: Name
  , mockOptions :: MockOptions
  , originalType :: Type
    -- ^ the method type as written in the class
  , fnNameStr :: String
    -- ^ generated builder name (prefix <> method name <> suffix)
  , mockFnName :: Name
  , paramsName :: Name
    -- ^ type variable used for the builder's parameter (always @p@)
  , updatedType :: Type
    -- ^ 'originalType' after applying the class's variable substitutions
  , fnType :: Type
    -- ^ 'updatedType', with the monad stripped when implicit returns are on
  }

-- | Which declaration generator a method uses.
data MockFnBuilder
  = VariadicBuilder
  | ConstantImplicitBuilder
  | ConstantExplicitBuilder

-- | Assemble a 'MockFnContext' for the method @sigFnName :: ty@.
buildMockFnContext ::
  MockType -> Name -> [VarAppliedType] -> MockOptions -> Name -> Type -> MockFnContext
buildMockFnContext mockType monadVarName varAppliedTypes mockOptions sigFnName ty =
  let fnNameStr = createFnName sigFnName mockOptions
      mockFnName = mkName fnNameStr
      params = mkName "p"
      updatedType = updateType ty varAppliedTypes
      fnType =
        if mockOptions.implicitMonadicReturn
          then createMockBuilderFnType monadVarName updatedType
          else updatedType
   in MockFnContext
        { mockType
        , monadVarName
        , mockOptions
        , originalType = ty
        , fnNameStr
        , mockFnName
        , paramsName = params
        , updatedType
        , fnType
        }

-- | Generate the builder declarations for a method, dispatching on
-- 'determineMockFnBuilder'.
buildMockFnDeclarations :: MockFnContext -> Q [Dec]
buildMockFnDeclarations
  ctx@MockFnContext {mockType, fnNameStr, mockFnName, paramsName, fnType, monadVarName, updatedType} =
    case determineMockFnBuilder ctx of
      VariadicBuilder ->
        doCreateMockFnDecs mockType fnNameStr mockFnName paramsName fnType monadVarName updatedType
      ConstantImplicitBuilder ->
        doCreateConstantMockFnDecs mockType fnNameStr mockFnName fnType monadVarName
      ConstantExplicitBuilder ->
        doCreateEmptyVerifyParamMockFnDecs fnNameStr mockFnName paramsName fnType monadVarName updatedType

-- | Function types use the variadic builder; constants use the implicit or
-- explicit builder depending on 'implicitMonadicReturn'.
determineMockFnBuilder :: MockFnContext -> MockFnBuilder
determineMockFnBuilder ctx
  | isNotConstantFunctionType (originalType ctx) = VariadicBuilder
  | (mockOptions ctx).implicitMonadicReturn = ConstantImplicitBuilder
  | otherwise = ConstantExplicitBuilder

-- | @{-# NOINLINE name #-}@.
createNoInlinePragma :: Name -> Q Dec
createNoInlinePragma name = pragInlD name NoInline FunLike AllPhases

-- | Define @mockFunName p = mockBody@. The argument is bound with the
-- dynamically-scoped name @p@ so that the free @p@ used in the quoted
-- bodies below (@[|p|]@, @[|Head :> param p|]@) resolves to it.
defineMockFn :: (Quote m) => Name -> Exp -> m Dec
defineMockFn mockFunName mockBody =
  funD mockFunName [clause [varP (mkName "p")] (normalB (pure mockBody)) []]

-- | Declarations for a function-typed method:
-- @name :: ctx => p -> MockT m funType@ plus its body.
--
-- The 'MockType' argument is accepted for interface stability; Partial and
-- Total mocks currently require exactly the same constraint set, so it is
-- not inspected (the two former branches were identical).
doCreateMockFnDecs ::
  (Quote m) => MockType -> String -> Name -> Name -> Type -> Name -> Type -> m [Dec]
doCreateMockFnDecs _mockType funNameStr mockFunName params funType monadVarName updatedType = do
  let verifyParams = createMockBuilderVerifyParams updatedType
      mockBuilderPred =
        AppT (AppT (AppT (ConT ''MockBuilder) (VarT params)) funType) verifyParams
      ctx =
        [mockBuilderPred | verifyParams /= TupleT 0]
          ++ [AppT (ConT ''MonadIO) (VarT monadVarName)]
          ++ partialAdditionalPredicates funType verifyParams
          ++ createTypeablePreds [funType, verifyParams]
      resultType =
        AppT
          (AppT ArrowT (VarT params))
          (AppT (AppT (ConT ''MockT) (VarT monadVarName)) funType)
  newFunSig <- sigD mockFunName (pure (ForallT [] ctx resultType))
  mockBody <- createMockBody funNameStr [|p|] funType
  newFun <- defineMockFn mockFunName mockBody
  pure [newFunSig, newFun]

-- | Declarations for a constant method when implicit monadic return is on.
-- The generated builder takes the stub value itself and wraps it as
-- @Head :> param p@.
doCreateConstantMockFnDecs :: (Quote m) => MockType -> String -> Name -> Type -> Name -> m [Dec]
doCreateConstantMockFnDecs Partial funNameStr mockFunName _ monadVarName = do
  stubVar <- newName "r"
  let ctx =
        [ AppT (AppT EqualityT (AppT (ConT ''ResolvableParamsOf) (VarT stubVar))) (TupleT 0)
        , AppT (ConT ''MonadIO) (VarT monadVarName)
        , AppT (ConT ''Typeable) (VarT stubVar)
        , AppT (ConT ''Show) (VarT stubVar)
        , AppT (ConT ''Eq) (VarT stubVar)
        ]
      resultType =
        AppT
          (AppT ArrowT (VarT stubVar))
          (AppT (AppT (ConT ''MockT) (VarT monadVarName)) (VarT stubVar))
  newFunSig <-
    sigD
      mockFunName
      ( pure
          ( ForallT
              [ PlainTV stubVar SpecifiedSpec
              , PlainTV monadVarName SpecifiedSpec
              ]
              ctx
              resultType
          )
      )
  headParam <- [|Head :> param p|]
  mockBody <- createMockBody funNameStr (pure headParam) (VarT stubVar)
  newFun <- defineMockFn mockFunName mockBody
  pure [newFunSig, newFun]
doCreateConstantMockFnDecs Total funNameStr mockFunName ty monadVarName = do
  (newFunSig, funTypeForBody) <- case ty of
    -- A constant of shape @T m@ (constructor applied to the monad variable):
    -- generalise the stub to a fresh type variable.
    AppT (ConT _) (VarT mv)
      | mv == monadVarName -> do
          a <- newName "a"
          let ctx =
                [ AppT (ConT ''MonadIO) (VarT monadVarName)
                , AppT (AppT EqualityT (AppT (ConT ''ResolvableParamsOf) (VarT a))) (TupleT 0)
                , AppT (ConT ''Typeable) (VarT a)
                , AppT (ConT ''Show) (VarT a)
                , AppT (ConT ''Eq) (VarT a)
                ]
              resultType =
                AppT
                  (AppT ArrowT (VarT a))
                  (AppT (AppT (ConT ''MockT) (VarT monadVarName)) (VarT a))
          sig <-
            sigD
              mockFunName
              ( pure
                  ( ForallT
                      [PlainTV a SpecifiedSpec, PlainTV monadVarName SpecifiedSpec]
                      ctx
                      resultType
                  )
              )
          pure (sig, VarT a)
    _ -> do
      let headParamType =
            AppT (AppT (ConT ''(:>)) (ConT ''Head)) (AppT (ConT ''Param) ty)
          verifyParams' = createMockBuilderVerifyParams ty
          mockBuilderPred' =
            AppT (AppT (AppT (ConT ''MockBuilder) headParamType) ty) (TupleT 0)
          ctx =
            [AppT (ConT ''MonadIO) (VarT monadVarName)]
              ++ [mockBuilderPred' | verifyParams' /= TupleT 0]
              ++ createTypeablePreds [ty]
          resultType =
            AppT (AppT ArrowT ty) (AppT (AppT (ConT ''MockT) (VarT monadVarName)) ty)
      sig <-
        sigD mockFunName (pure (ForallT [PlainTV monadVarName SpecifiedSpec] ctx resultType))
      pure (sig, ty)
  headParam <- [|Head :> param p|]
  mockBody <- createMockBody funNameStr (pure headParam) funTypeForBody
  newFun <- defineMockFn mockFunName mockBody
  pure [newFunSig, newFun]

-- | Declarations for a constant method when implicit monadic return is off.
-- Unlike 'doCreateMockFnDecs', the 'MockBuilder' constraint is always emitted.
doCreateEmptyVerifyParamMockFnDecs ::
  (Quote m) => String -> Name -> Name -> Type -> Name -> Type -> m [Dec]
doCreateEmptyVerifyParamMockFnDecs funNameStr mockFunName params funType monadVarName updatedType = do
  let verifyParams = createMockBuilderVerifyParams updatedType
      mockBuilderPred =
        AppT (AppT (AppT (ConT ''MockBuilder) (VarT params)) funType) verifyParams
      ctx =
        [mockBuilderPred]
          ++ [AppT (ConT ''MonadIO) (VarT monadVarName)]
          ++ partialAdditionalPredicates funType verifyParams
          ++ createTypeablePreds [funType, verifyParams]
      resultType =
        AppT
          (AppT ArrowT (VarT params))
          (AppT (AppT (ConT ''MockT) (VarT monadVarName)) funType)
  newFunSig <- sigD mockFunName (pure (ForallT [] ctx resultType))
  mockBody <- createMockBody funNameStr [|p|] funType
  newFun <- defineMockFn mockFunName mockBody
  pure [newFunSig, newFun]

-- | Body of a generated builder: build the mock from @paramsExp@, register
-- it under @funNameStr@, record a 'Definition' in the 'MockT' state, and
-- return the registered instance. The type argument is currently unused.
createMockBody :: (Quote m) => String -> m Exp -> Type -> m Exp
createMockBody funNameStr paramsExp _funType = do
  params <- paramsExp
  [|
    MockT $ do
      -- Build the mock instance and its verifier directly so we have access
      -- to the verifier value (avoids runtime type-mismatch when resolving).
      BuiltMock {builtMockFn = mockInstance, builtMockRecorder = verifier} <-
        liftIO $ buildMock (Just $(litE (stringL funNameStr))) $(pure params)
      -- Register and get the canonical wrapper (preserved for async safety)
      canonicalInstance <-
        liftIO $ Registry.register (Just $(litE (stringL funNameStr))) verifier mockInstance
      addDefinition
        ( Definition
            (Proxy :: Proxy $(litT (strTyLit funNameStr)))
            canonicalInstance
            NoVerification
        )
      pure canonicalInstance
    |]

-- | Builder name: @prefix <> method name <> suffix@.
createFnName :: Name -> MockOptions -> String
createFnName funName opts = opts.prefix <> nameBase funName <> opts.suffix

-- | Find the definition registered under the symbol @pa@ and return its
-- mock function as a 'Dynamic'.
findParam :: KnownSymbol sym => Proxy sym -> [Definition] -> Maybe Dynamic
findParam pa definitions = do
  let definition = find (\(Definition s _ _) -> symbolVal s == symbolVal pa) definitions
  fmap (\(Definition _ mockFunction _) -> toDyn mockFunction) definition

-- | One fresh name per argument of a (possibly @forall@-quantified)
-- function type.
typeToNames :: Type -> [Q Name]
typeToNames (AppT (AppT ArrowT _) t2) = newName "a" : typeToNames t2
typeToNames (ForallT _ _ ty) = typeToNames ty
typeToNames _ = []

-- | Total list indexing: 'Nothing' for negative or out-of-range indices.
safeIndex :: [a] -> Int -> Maybe a
safeIndex [] _ = Nothing
safeIndex (x : _) 0 = Just x
safeIndex (_ : xs) n
  | n < 0 = Nothing
  | otherwise = safeIndex xs (n - 1)

-- | Shared skeleton of the instance method bodies: look up the definition
-- named @fnNameStr@, apply the stored mock to @args@ (binding the result
-- strictly to @r@) and return it with 'pure' or 'lift'. When no definition
-- is registered, the @fallback@ expression is used instead.
generateInstanceFnBody :: String -> [Q Exp] -> Name -> MockOptions -> Q Exp -> Q Exp
generateInstanceFnBody fnNameStr args r opts fallback = do
  returnExp <-
    if opts.implicitMonadicReturn
      then [|pure $(varE r)|]
      else [|lift $(varE r)|]
  [|
    MockT $ do
      defs <- getDefinitions
      let findDef = find (\(Definition s _ _) -> symbolVal s == $(litE (stringL fnNameStr))) defs
      case findDef of
        Just (Definition _ mf _) -> do
          -- NOTE(review): relies on the stored mock having the method's type.
          let mock = unsafeCoerce mf
          let $(bangP $ varP r) = $(generateStubFn args [|mock|])
          $(pure returnExp)
        Nothing -> $(fallback)
    |]

-- | Instance body for a mocked method; calls 'error' when no stub for
-- @fnNameStr@ has been registered.
generateInstanceMockFnBody :: String -> [Q Exp] -> Name -> MockOptions -> Q Exp
generateInstanceMockFnBody fnNameStr args r opts =
  generateInstanceFnBody
    fnNameStr
    args
    r
    opts
    [|error $ "no answer found stub function `" ++ fnNameStr ++ "`."|]

-- | Instance body for a partially mocked method; falls back to the real
-- implementation @fnName@ (lifted into 'MockT') when no stub is registered.
generateInstanceRealFnBody :: Name -> String -> [Q Exp] -> Name -> MockOptions -> Q Exp
generateInstanceRealFnBody fnName fnNameStr args r opts =
  generateInstanceFnBody
    fnNameStr
    args
    r
    opts
    [|lift $ $(foldl appE (varE fnName) args)|]

-- | Apply @mock@ to each argument in order; with no arguments this is just
-- @mock@.
generateStubFn :: [Q Exp] -> Q Exp -> Q Exp
generateStubFn args mock = foldl appE mock args