{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ConstraintKinds #-} {-# OPTIONS_GHC -Wno-missing-export-lists #-} {-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} {-# HLINT ignore "Use null" #-} {-# OPTIONS_GHC -Wno-name-shadowing #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE GADTs #-} {-# OPTIONS_GHC -Wno-redundant-constraints #-} module Test.MockCat.Verify where import Control.Concurrent.STM (TVar, readTVarIO) import Test.MockCat.Internal.Types import Control.Monad (guard, when, unless) import Data.List (elemIndex, intercalate) import Data.Maybe import Test.MockCat.Param import Prelude hiding (lookup) import GHC.Stack (HasCallStack) import Test.MockCat.Internal.Message import Data.Kind (Type, Constraint) import Test.MockCat.Cons ((:>)) import Data.Typeable (Typeable, eqT) import Test.MockCat.Internal.MockRegistry (lookupVerifierForFn, withAllUnitGuards) import Data.Type.Equality ((:~:) (Refl)) import Data.Dynamic (fromDynamic) import GHC.TypeLits (TypeError, ErrorMessage(..), Symbol) -- | Class for verifying mock function. 
--
-- Resolves @m@ to its registered recorder, reads the recorded invocations
-- and checks them with 'doVerify'. A reported failure is raised through
-- 'errorWithoutStackTrace'.
verify ::
  ( ResolvableMock m
  , Eq (ResolvableParamsOf m)
  , Show (ResolvableParamsOf m)
  ) =>
  m ->
  VerifyMatchType (ResolvableParamsOf m) ->
  IO ()
verify m matchType = do
  ResolvedMock mockName recorder <- requireResolved m
  recorded <- readInvocationList (invocationRef recorder)
  maybe (pure ()) raise (doVerify mockName recorded matchType)
  where
    raise (VerifyFailed msg) = errorWithoutStackTrace msg

-- | Pure core of 'verify'. Returns 'Nothing' when the recorded invocations
-- satisfy the match condition; otherwise returns the failure built by
-- 'verifyFailedMessage'.
--
-- * 'MatchAny' fails when the expected value is not among the calls.
-- * 'MatchAll' fails when some recorded call differs from the expected
--   value. An empty invocation list therefore passes.
doVerify :: (Eq a, Show a) => Maybe MockName -> InvocationList a -> VerifyMatchType a -> Maybe VerifyFailed
doVerify name list (MatchAny a)
  | notElem a list = Just (verifyFailedMessage name list a)
  | otherwise = Nothing
doVerify name list (MatchAll a)
  | Prelude.any (a /=) list = Just (verifyFailedMessage name list a)
  | otherwise = Nothing

-- | Read the invocations currently stored in a recorder's 'TVar'.
readInvocationList :: TVar (InvocationRecord params) -> IO (InvocationList params)
readInvocationList ref = invocations <$> readTVarIO ref

-- | Verify that a resolved mock function was called at least once.
-- This is used internally by typeclass mock verification.
--
-- Raises an error via 'errorWithoutStackTrace' when the recorder holds no
-- invocations at all.
verifyResolvedAny :: ResolvedMock params -> IO ()
verifyResolvedAny (ResolvedMock mockName recorder) = do
  invocationList <- readInvocationList (invocationRef recorder)
  when (null invocationList) $
    errorWithoutStackTrace $
      intercalate "\n" ["Function" <> mockNameLabel mockName <> " was never called"]

-- | Decide whether an observed call count (the second argument) satisfies
-- the expected bound carried by the 'CountVerifyMethod'.
compareCount :: CountVerifyMethod -> Int -> Bool
compareCount (Equal e) a = a == e
compareCount (LessThanEqual e) a = a <= e
compareCount (LessThan e) a = a < e
compareCount (GreaterThanEqual e) a = a >= e
compareCount (GreaterThan e) a = a > e

-- | Verify that the mock was called with exactly @v@ (compared with '==')
-- a number of times that satisfies @method@.
--
-- On mismatch, an error built by 'countWithArgsMismatchMessage' is raised
-- via 'errorWithoutStackTrace'.
verifyCount ::
  ( ResolvableMock m
  , Eq (ResolvableParamsOf m)
  ) =>
  m ->
  ResolvableParamsOf m ->
  CountVerifyMethod ->
  IO ()
verifyCount m v method = do
  ResolvedMock mockName recorder <- requireResolved m
  invocationList <- readInvocationList (invocationRef recorder)
  let callCount = length (filter (v ==) invocationList)
  unless (compareCount method callCount) $
    errorWithoutStackTrace $
      countWithArgsMismatchMessage mockName method callCount

-- | Generate error message for count mismatch with arguments
countWithArgsMismatchMessage :: Maybe MockName -> CountVerifyMethod -> Int -> String
countWithArgsMismatchMessage mockName method callCount =
  intercalate
    "\n"
    [ "function" <> mockNameLabel mockName <> " was not called the expected number of times with the expected arguments."
    , " expected: " <> show method
    , " but got: " <> show callCount
    ]

-- | Verify that the mock's recorded invocations follow @matchers@ according
-- to the given 'VerifyOrderMethod'. Failures are raised via
-- 'errorWithoutStackTrace'.
verifyOrder ::
  ( ResolvableMock m
  , Eq (ResolvableParamsOf m)
  , Show (ResolvableParamsOf m)
  ) =>
  VerifyOrderMethod ->
  m ->
  [ResolvableParamsOf m] ->
  IO ()
verifyOrder method m matchers = do
  ResolvedMock mockName recorder <- requireResolved m
  invocationList <- readInvocationList (invocationRef recorder)
  case doVerifyOrder method mockName invocationList matchers of
    Nothing -> pure ()
    Just (VerifyFailed msg) -> errorWithoutStackTrace msg

-- | Pure core of 'verifyOrder'. Returns 'Nothing' on success.
--
-- * 'ExactlySequence': the recorded calls must have the same length as the
--   expectations and match them position by position.
-- * 'PartiallySequence': there must be at least as many calls as
--   expectations, and the expectations must appear in order among the
--   calls. Other calls may be interleaved.
doVerifyOrder ::
  (Eq a, Show a) =>
  VerifyOrderMethod ->
  Maybe MockName ->
  InvocationList a ->
  [a] ->
  Maybe VerifyFailed
doVerifyOrder ExactlySequence name calledValues expectedValues
  | length calledValues /= length expectedValues =
      Just $ verifyFailedOrderParamCountMismatch name calledValues expectedValues
  | otherwise = do
      let unexpectedOrders = collectUnExpectedOrder calledValues expectedValues
      guard $ not (null unexpectedOrders)
      pure $ verifyFailedSequence name unexpectedOrders
doVerifyOrder PartiallySequence name calledValues expectedValues
  | length calledValues < length expectedValues =
      Just $ verifyFailedOrderParamCountMismatch name calledValues expectedValues
  | otherwise = do
      guard $ isOrderNotMatched calledValues expectedValues
      pure $ verifyFailedPartiallySequence name calledValues expectedValues

-- | Failure message for 'PartiallySequence'. It lists the expected order
-- and then every recorded call, one per line.
verifyFailedPartiallySequence :: Show a => Maybe MockName -> InvocationList a -> [a] -> VerifyFailed
verifyFailedPartiallySequence name calledValues expectedValues =
  VerifyFailed $
    intercalate
      "\n"
      [ "function" <> mockNameLabel name <> " was not called with the expected arguments in the expected order."
      , " expected order:"
      , intercalate "\n" $ (" " <>) . show <$> expectedValues
      , " but got:"
      , intercalate "\n" $ (" " <>) . show <$> calledValues
      ]

-- | 'True' when @expectedValues@ cannot be found, in order, within
-- @calledValues@.
--
-- Each expectation is matched against the earliest remaining call using
-- 'elemIndex'. The search stops at the first expectation that is not found.
isOrderNotMatched :: Eq a => InvocationList a -> [a] -> Bool
isOrderNotMatched calledValues expectedValues =
  isNothing (consume calledValues expectedValues)
  where
    consume remaining [] = Just remaining
    consume remaining (e : es) = do
      pos <- elemIndex e remaining
      -- Only calls after the match may satisfy later expectations.
      consume (drop (pos + 1) remaining) es

-- | Failure message used when the number of recorded calls is incompatible
-- with the number of expectations.
verifyFailedOrderParamCountMismatch :: Maybe MockName -> InvocationList a -> [a] -> VerifyFailed
verifyFailedOrderParamCountMismatch name calledValues expectedValues =
  VerifyFailed $
    intercalate
      "\n"
      [ "function" <> mockNameLabel name <> " was not called with the expected arguments in the expected order (count mismatch)."
      , " expected: " <> show (length expectedValues)
      , " but got: " <> show (length calledValues)
      ]

-- | Failure message for 'ExactlySequence'. A header line is followed by
-- one line per mismatching position.
verifyFailedSequence :: Show a => Maybe MockName -> [VerifyOrderResult a] -> VerifyFailed
verifyFailedSequence name fails =
  VerifyFailed $
    intercalate
      "\n"
      ( ("function" <> mockNameLabel name <> " was not called with the expected arguments in the expected order.")
          : (verifyOrderFailedMesssage <$> fails)
      )

-- | Collect every position where the expected value differs (via '/=')
-- from the recorded call at the same index.
--
-- Pairing the lists with 'zip3' keeps this total. Only indices present in
-- both lists are compared, so a shorter @calledValues@ cannot cause an
-- out-of-range lookup.
collectUnExpectedOrder :: Eq a => InvocationList a -> [a] -> [VerifyOrderResult a]
collectUnExpectedOrder calledValues expectedValues =
  [ VerifyOrderResult {index = i, calledValue, expectedValue}
  | (i, calledValue, expectedValue) <- zip3 [0 ..] calledValues expectedValues
  , expectedValue /= calledValue
  ]

-- | 'map' that also passes each element's zero-based index.
mapWithIndex :: (Int -> a -> b) -> [a] -> [b]
mapWithIndex f xs = [f i x | (i, x) <- zip [0 ..] xs]

-- Legacy shouldApply* helpers removed. Use shouldBeCalled API instead.
-- | Prepend @Param a@ to an existing parameter chain. The empty chain @()@
-- becomes a bare @Param a@ rather than @Param a :> ()@.
type family PrependParam a rest where
  PrependParam a () = Param a
  PrependParam a rest = Param a :> rest

-- | Parameter chain of a function type: one 'Param' per argument, joined
-- with ':>'. A non-function type contributes @()@.
type family FunctionParams fn where
  FunctionParams (a -> fn) = PrependParam a (FunctionParams fn)
  FunctionParams fn = ()

-- | Parameter type recorded for a mock target. This is the
-- 'FunctionParams' of a function type, and @()@ for anything else.
type family ResolvableParamsOf target :: Type where
  ResolvableParamsOf (a -> fn) = FunctionParams (a -> fn)
  ResolvableParamsOf target = ()

-- | Type-level boolean disjunction.
type family Or (a :: Bool) (b :: Bool) :: Bool where
  Or 'True _ = 'True
  Or _ 'True = 'True
  Or 'False 'False = 'False

-- | Type-level boolean negation.
type family Not (a :: Bool) :: Bool where
  Not 'True = 'False
  Not 'False = 'True

-- | @'True@ exactly for function types.
type family IsFunctionType target :: Bool where
  IsFunctionType (_a -> _b) = 'True
  IsFunctionType _ = 'False

-- | @'True@ exactly for @IO@ actions.
type family IsIOType target :: Bool where
  IsIOType (IO _) = 'True
  IsIOType _ = 'False

-- | A pure constant is a target that is neither a function nor an @IO@
-- action.
type family IsPureConstant target :: Bool where
  IsPureConstant target = Not (Or (IsFunctionType target) (IsIOType target))

-- | Reject pure-constant targets for the API named @fn@ with a custom
-- compile-time error. Callable targets get the empty constraint.
type family RequireCallable (fn :: Symbol) target :: Constraint where
  RequireCallable fn target = RequireCallableImpl fn (IsPureConstant target) target

-- | Worker for 'RequireCallable'. It dispatches on the result of
-- 'IsPureConstant'.
type family RequireCallableImpl (fn :: Symbol) (isPure :: Bool) target :: Constraint where
  RequireCallableImpl fn 'True target =
    TypeError
      ( 'Text fn ':<>: 'Text " is not available for pure constant mocks."
          ':$$: 'Text " target type: " ':<>: 'ShowType target
          ':$$: 'Text " hint: convert it into a callable mock or use shouldBeCalled with 'anything'."
      )
  RequireCallableImpl _ 'False _ = ()

-- | Constraint alias for resolvable mock types.
type ResolvableMock m =
  ( Typeable (ResolvableParamsOf m)
  , Typeable (InvocationRecorder (ResolvableParamsOf m))
  )

-- | Constraint alias for resolvable mock types with specific params.
type ResolvableMockWithParams m params = (ResolvableParamsOf m ~ params, ResolvableMock m)

-- | Look up the invocation recorder registered for @target@.
--
-- When the parameter type is @()@, the lookup runs under
-- 'withAllUnitGuards'. The result is 'Nothing' when no verifier is found,
-- or when the stored dynamic value is not an @InvocationRecorder params@.
resolveForVerification ::
  forall target params.
  ( params ~ ResolvableParamsOf target
  , Typeable params
  , Typeable (InvocationRecorder params)
  ) =>
  target ->
  IO (Maybe (Maybe MockName, InvocationRecorder params))
resolveForVerification target = do
  let lookupVerifier = lookupVerifierForFn target
  found <- case eqT :: Maybe (params :~: ()) of
    Just Refl -> withAllUnitGuards lookupVerifier
    Nothing -> lookupVerifier
  -- Evaluate the outcome to WHNF here rather than returning a thunk.
  pure $! found >>= \(name, dynVerifier) ->
    (,) name <$> fromDynamic @(InvocationRecorder params) dynVerifier

-- | Verify that a function was called the expected number of times,
-- regardless of the arguments it received.
verifyCallCount :: Maybe MockName -> InvocationRecorder params -> CountVerifyMethod -> IO ()
verifyCallCount maybeName recorder method = do
  callCount <- length <$> readInvocationList (invocationRef recorder)
  unless (compareCount method callCount) $
    errorWithoutStackTrace (countMismatchMessage maybeName method callCount)

-- | Generate error message for count mismatch
countMismatchMessage :: Maybe MockName -> CountVerifyMethod -> Int -> String
countMismatchMessage maybeName method callCount =
  intercalate
    "\n"
    [ "function" <> mockNameLabel maybeName <> " was not called the expected number of times."
    , " expected: " <> describe method
    , " but got: " <> show callCount
    ]
  where
    describe = \case
      Equal n -> show n
      LessThanEqual n -> "<= " <> show n
      GreaterThanEqual n -> ">= " <> show n
      LessThan n -> "< " <> show n
      GreaterThan n -> "> " <> show n

-- | Abort with 'verificationFailureMessage'.
verificationFailure :: IO a
verificationFailure = errorWithoutStackTrace verificationFailureMessage

-- | A mock matched to its registry entry: its optional name and the
-- recorder holding its invocations.
data ResolvedMock params = ResolvedMock
  { resolvedMockName :: Maybe MockName
  , resolvedMockRecorder :: InvocationRecorder params
  }

-- | Like 'resolveForVerification', but aborts with 'verificationFailure'
-- when the target cannot be resolved.
requireResolved ::
  forall target params.
  ( params ~ ResolvableParamsOf target
  , Typeable params
  , Typeable (InvocationRecorder params)
  ) =>
  target ->
  IO (ResolvedMock params)
requireResolved target =
  resolveForVerification target
    >>= maybe verificationFailure (pure . uncurry ResolvedMock)

-- | Error text shown when a value passed for verification is not a
-- registered mock.
verificationFailureMessage :: String
verificationFailureMessage = intercalate "\n" messageLines
  where
    messageLines =
      [ "Error: 'shouldBeCalled' can only verify functions created by 'mock'."
      , ""
      , "The value you passed could not be recognized as a mock function."
      , ""
      , "This usually happens in one of the following cases:"
      , " - You passed a normal (non-mock) function."
      , " - You passed a stub or value not created via 'mock' / 'mockIO'."
      , " - You are trying to verify a value that was never registered as a mock."
      , ""
      , "How to fix it:"
      , " 1. Make sure you created the function with 'mock' (or 'mockIO' for IO)"
      , " before calling 'shouldBeCalled'."
      , " 2. Pass that mock value directly to 'shouldBeCalled'"
      , " (not the original function or a plain value)."
      , ""
      , "If this message still appears, check that:"
      , " - You are not passing a pure constant."
      , " - The mock value is still in scope where 'shouldBeCalled' is used."
      , ""
      , "Tip: If you prefer automatic verification,"
      , "consider using 'withMock', which runs all expectations at the end"
      , "of the block."
      ]

-- ============================================
-- shouldBeCalled API
-- ============================================

-- | What a 'shouldBeCalled' check should assert about a mock.
data VerificationSpec params where
  -- | Call count restricted to calls with the given arguments.
  CountVerification :: CountVerifyMethod -> params -> VerificationSpec params
  -- | Call count over all calls, whatever their arguments.
  CountAnyVerification :: CountVerifyMethod -> VerificationSpec params
  -- | Order of calls, exact or partial depending on the method.
  OrderVerification :: VerifyOrderMethod -> [params] -> VerificationSpec params
  -- | At least one call with the given arguments.
  SimpleVerification :: params -> VerificationSpec params
  -- | At least one call with any arguments.
  AnyVerification :: VerificationSpec params

-- | Times condition for count verification
newtype TimesSpec = TimesSpec CountVerifyMethod

-- | Create a times condition for exact count.
--
-- > f `shouldBeCalled` times 3
-- > f `shouldBeCalled` (times 3 `with` "arg")
times :: Int -> TimesSpec
times = TimesSpec . Equal

-- | Create a times condition for at least count (>=).
--
-- > f `shouldBeCalled` atLeast 1
atLeast :: Int -> TimesSpec
atLeast = TimesSpec . GreaterThanEqual

-- | Create a times condition for at most count (<=).
--
-- > f `shouldBeCalled` atMost 2
atMost :: Int -> TimesSpec
atMost = TimesSpec . LessThanEqual

-- | Create a times condition for greater than count (>).
greaterThan :: Int -> TimesSpec
greaterThan = TimesSpec . GreaterThan

-- | Create a times condition for less than count (<).
lessThan :: Int -> TimesSpec
lessThan = TimesSpec . LessThan

-- | Create a times condition for exactly once.
-- Equivalent to 'times 1'.
once :: TimesSpec
once = times 1

-- | Create a times condition for never (zero times).
-- Equivalent to 'times 0'.
never :: TimesSpec
never = TimesSpec (Equal 0)

-- | Order condition for order verification
newtype OrderSpec = OrderSpec VerifyOrderMethod

-- | Create an order condition for exact sequence
inOrder :: OrderSpec
inOrder = OrderSpec ExactlySequence

-- | Create an order condition for partial sequence
inPartialOrder :: OrderSpec
inPartialOrder = OrderSpec PartiallySequence

-- | Create a simple verification with arguments.
-- This accepts both raw values and Param chains.
--
-- > f `shouldBeCalled` calledWith "a"
calledWith :: params -> VerificationSpec params
calledWith = SimpleVerification

-- | Create a simple verification without arguments.
-- It verifies that the function was called at least once, with ANY arguments.
--
-- > f `shouldBeCalled` anything
anything :: forall params. VerificationSpec params
anything = AnyVerification

-- | Type class for combining times condition with arguments
class WithArgs spec params where
  type WithResult spec params :: Type
  with :: spec -> params -> WithResult spec params

-- | Instance for times condition with arguments
instance (Eq params, Show params) => WithArgs TimesSpec params where
  type WithResult TimesSpec params = VerificationSpec params
  with (TimesSpec method) = CountVerification method

-- | Type family to normalize argument types for 'withArgs'. Param chains
-- and single 'Param's are kept as-is; any other type is wrapped in 'Param'.
type family NormalizeWithArg a :: Type where
  NormalizeWithArg (Param a :> rest) = Param a :> rest
  NormalizeWithArg (Param a) = Param a
  NormalizeWithArg a = Param a

-- | Type class to normalize argument types (to Param or Param chain)
class ToNormalizedArg a where
  toNormalizedArg :: a -> NormalizeWithArg a

instance ToNormalizedArg (Param a :> rest) where
  toNormalizedArg = id

instance ToNormalizedArg (Param a) where
  toNormalizedArg = id

-- | Fallback for raw values: wrap them with 'wrap'.
instance {-# OVERLAPPABLE #-} (NormalizeWithArg a ~ Param a, WrapParam a) => ToNormalizedArg a where
  toNormalizedArg = wrap

-- | Combine a times condition with expected arguments. Unlike 'with',
-- raw values are accepted and normalized through 'ToNormalizedArg'.
-- This is intended to replace 'with' once the old 'with' is removed.
withArgs ::
  forall params.
  ( ToNormalizedArg params
  , Eq (NormalizeWithArg params)
  , Show (NormalizeWithArg params)
  ) =>
  TimesSpec ->
  params ->
  VerificationSpec (NormalizeWithArg params)
withArgs (TimesSpec method) args = CountVerification method (toNormalizedArg args)

infixl 8 `withArgs`

-- | Verify that the mock was called with the specified sequence of arguments in exact order.
--
-- > f `shouldBeCalled` inOrderWith ["a", "b"]
inOrderWith ::
  forall params.
  ( ToNormalizedArg params
  , Eq (NormalizeWithArg params)
  , Show (NormalizeWithArg params)
  ) =>
  [params] ->
  VerificationSpec (NormalizeWithArg params)
inOrderWith args = OrderVerification ExactlySequence (map toNormalizedArg args)

-- | Verify that the mock was called with the specified sequence of arguments, allowing other calls in between.
--
-- > f `shouldBeCalled` inPartialOrderWith ["a", "c"]
-- > -- This passes if calls were: "a", "b", "c"
inPartialOrderWith ::
  forall params.
  ( ToNormalizedArg params
  , Eq (NormalizeWithArg params)
  , Show (NormalizeWithArg params)
  ) =>
  [params] ->
  VerificationSpec (NormalizeWithArg params)
inPartialOrderWith args = OrderVerification PartiallySequence (map toNormalizedArg args)

-- | Main verification function class
class ShouldBeCalled m spec where
  shouldBeCalled :: HasCallStack => m -> spec -> IO ()

-- | Instance for times spec alone (without arguments)
instance
  ( ResolvableMockWithParams m params
  , RequireCallable "shouldBeCalled" m
  ) =>
  ShouldBeCalled m TimesSpec
  where
  shouldBeCalled m (TimesSpec method) = do
    ResolvedMock mockName verifier <- requireResolved m
    verifyCallCount mockName verifier method

-- | Instance for VerificationSpec (handles all verification types)
instance {-# OVERLAPPING #-}
  ( ResolvableMockWithParams m params
  , Eq params
  , Show params
  , RequireCallable "shouldBeCalled" m
  ) =>
  ShouldBeCalled m (VerificationSpec params)
  where
  shouldBeCalled m spec = case spec of
    CountVerification method args -> verifyCount m args method
    CountAnyVerification count -> do
      ResolvedMock mockName recorder <- requireResolved m
      verifyCallCount mockName recorder count
    OrderVerification method argsList -> verifyOrder method m argsList
    SimpleVerification args -> verify m (MatchAny args)
    -- Reuse the shared "called at least once" check instead of
    -- re-implementing the read and the error message here.
    AnyVerification -> requireResolved m >>= verifyResolvedAny

-- | Instance for Param chains (e.g., "a" ~> "b")
instance {-# OVERLAPPING #-}
  ( ResolvableMockWithParams m (Param a :> rest)
  , Eq (Param a :> rest)
  , Show (Param a :> rest)
  ) =>
  ShouldBeCalled m (Param a :> rest)
  where
  shouldBeCalled m args = verify m (MatchAny args)

-- | Instance for single Param (e.g., param "a")
instance {-# OVERLAPPING #-}
  ( ResolvableMockWithParams m (Param a)
  , Eq (Param a)
  , Show (Param a)
  ) =>
  ShouldBeCalled m (Param a)
  where
  shouldBeCalled m args = verify m (MatchAny args)

-- | Instance for raw values (e.g., "a")
-- This converts raw values to Param at runtime
instance {-# OVERLAPPABLE #-}
  ( ResolvableMockWithParams m (Param a)
  , Eq (Param a)
  , Show (Param a)
  , Show a
  , Eq a
  ) =>
  ShouldBeCalled m a
  where
  shouldBeCalled m arg = verify m (MatchAny (param arg))