{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-- | This module provides types and functions for representing mock parameters.
-- Parameters are used both for setting up expectations and for verification.
module Test.MockCat.Param
  ( Param(..),
    EqParams(..),
    WrapResult(wrapResult),
    value,
    param,
    ConsGen(..),
    MockSpec(..),
    expect,
    expect_,
    ToParamParam(..),
    ToParamArg(..),
    Normalize,
    any,
    ArgsOf,
    ProjectionArgs,
    projArgs,
    ReturnOf,
    ProjectionReturn,
    projReturn,
    returnValue,

  )
where

import Data.Typeable (Typeable, typeOf)
import Foreign.Ptr (Ptr, ptrToIntPtr, castPtr, IntPtr)
import Prelude hiding (any)
import System.IO.Unsafe (unsafePerformIO)
import System.Mem.StableName (eqStableName, makeStableName)
import Unsafe.Coerce (unsafeCoerce)

import qualified Data.Text as T (Text)

import Test.MockCat.Cons ((:>) (..), Head(..))
import Test.MockCat.Internal.Types (Cases)

-- '~>' is right-associative at precedence 1, so @a ~> b ~> c@ parses as
-- @a ~> (b ~> c)@: the right operand of each link is itself the rest of the chain.
infixr 1 ~>

-- | Pairs the stub parameters of a mock with the expectations attached to it.
--
-- The @exps@ type parameter is @()@ while no expectations have been set and
-- holds the collected expectations once 'expects' has been applied. Keeping
-- both in one wrapper means 'expects' can only be applied to a 'MockSpec',
-- never to the value produced by 'mock'.
data MockSpec params exps = MockSpec
  { specParams :: params
    -- ^ The stub parameters.
  , specExpectations :: exps
    -- ^ The attached expectations, or @()@ when there are none.
  }
  deriving (Eq, Show)

-- | One parameter of a mock: a concrete expected value, a predicate, or a
-- wrapped value with no 'Eq'/'Show' requirement. Every constructor also
-- carries a 'String' label that is used when the parameter is displayed.
data Param v where
  -- | Expects an argument equal to the given value.
  ExpectValue :: forall v. (Show v, Eq v) => v -> String -> Param v
  -- | Expects an argument for which the predicate holds.
  ExpectCondition :: forall v. (v -> Bool) -> String -> Param v
  -- | Wraps a value that need not have 'Eq' or 'Show' instances.
  ValueWrapper :: forall v. v -> String -> Param v


-- | Turns a raw result value into a 'Param'.
--
-- The class itself imposes no 'Eq' or 'Show' constraint; individual
-- instances use them, where available, to produce a more readable label.
class WrapResult a where
  -- | Wrap a single result value.
  wrapResult :: a -> Param a

-- Result types that have 'Eq' and 'Show' are stored as 'ExpectValue' through
-- 'param', so they compare by equality and are labelled with 'show'.
instance {-# OVERLAPPING #-} WrapResult String where
  wrapResult = param

instance {-# OVERLAPPING #-} WrapResult Int where
  wrapResult = param

instance {-# OVERLAPPING #-} WrapResult Integer where
  wrapResult = param

instance {-# OVERLAPPING #-} WrapResult Bool where
  wrapResult = param

instance {-# OVERLAPPING #-} WrapResult Double where
  wrapResult = param

instance {-# OVERLAPPING #-} WrapResult Float where
  wrapResult = param

instance {-# OVERLAPPING #-} WrapResult Char where
  wrapResult = param

instance {-# OVERLAPPING #-} WrapResult T.Text where
  wrapResult = param

-- | Optional results are compared by equality.
--
-- NOTE(review): GHC commits to this instance for every @Maybe a@ before
-- checking its context, so a @Maybe a@ whose @a@ lacks 'Show' or 'Eq' is a
-- type error rather than falling back to the 'ValueWrapper' instance below.
instance {-# OVERLAPPABLE #-} (Show a, Eq a) => WrapResult (Maybe a) where
  wrapResult = param

-- | Fallback for any other result type: stored as a 'ValueWrapper' with the
-- fixed label @"ValueWrapper"@, since no 'Show' instance is assumed.
instance {-# OVERLAPPABLE #-} WrapResult a where
  wrapResult v = ValueWrapper v "ValueWrapper"

-- | Matching between parameters. This is intentionally not a lawful
-- equivalence (for example, it is not transitive):
--
-- * two concrete values ('ExpectValue' / 'ValueWrapper') use the carried
--   'Eq' instance, except two 'ValueWrapper's, which are compared by
--   pointer identity through 'compareFunction';
-- * a condition on either side is applied to the other side's value;
-- * two conditions match when either is labelled @"any"@, otherwise when
--   their labels are equal.
instance Eq (Param a) where
  lhs == rhs = case (lhs, rhs) of
    -- Both sides hold a concrete value.
    (ExpectValue x _, ExpectValue y _) -> x == y
    (ExpectValue x _, ValueWrapper y _) -> x == y
    (ValueWrapper x _, ExpectValue y _) -> x == y
    (ValueWrapper x _, ValueWrapper y _) -> compareFunction x y
    -- A condition on one side is tested against the other side's value.
    (ExpectCondition p _, ExpectValue y _) -> p y
    (ExpectCondition p _, ValueWrapper y _) -> p y
    (ExpectValue x _, ExpectCondition p _) -> p x
    (ValueWrapper x _, ExpectCondition p _) -> p x
    -- Two conditions: "any" on either side wins, else compare the labels.
    (ExpectCondition _ l, ExpectCondition _ r) -> l == "any" || r == "any" || l == r

-- | A parameter is displayed as its stored label, whatever its constructor.
instance Show (Param v) where
  show p = case p of
    ExpectValue _ label -> label
    ExpectCondition _ label -> label
    ValueWrapper _ label -> label

-- | Extract the concrete value held by a 'Param'.
--
-- 'ExpectValue' and 'ValueWrapper' both carry a value, which is returned
-- unchanged. An 'ExpectCondition' only carries a predicate, so there is no
-- value to return; asking for one is a usage error and raises 'error' with
-- a message that names the condition's label.
value :: Param v -> v
value (ExpectValue a _) = a
value (ValueWrapper a _) = a
value (ExpectCondition _ label) =
  error $
    "Test.MockCat.Param.value: the condition parameter ("
      ++ label
      ++ ") carries no concrete value"

-- | Build a parameter that matches by equality and is labelled with the
-- value's 'show' output. Needs both 'Eq' and 'Show'.
param :: (Show v, Eq v) => v -> Param v
param x = ExpectValue x rendered
  where
    rendered = show x


-- | The form a type takes inside a parameter chain; used as the result type
-- of the conversion classes behind 'ConsGen'.
--
-- Chains built with '(:>)', 'Head', and types already wrapped in 'Param'
-- are left unchanged; any other type @t@ becomes @Param t@. Equations are
-- tried top to bottom, so the final one only fires for types apart from the
-- earlier patterns.
type family Normalize t where
  Normalize (lhs :> rhs) = lhs :> rhs
  Normalize Head = Head
  Normalize (Param inner) = Param inner
  Normalize t = Param t

-- | Converts an argument value into its parameter-chain form ('Normalize').
class ToParamArg a where
  toParamArg :: a -> Normalize a

-- | Functions have neither 'Eq' nor 'Show', so they are stored as a
-- 'ValueWrapper' labelled through 'showFunction'.
instance {-# OVERLAPPING #-} (Typeable (a -> b)) => ToParamArg (a -> b) where
  toParamArg fn = ValueWrapper fn $ showFunction fn

-- | Equality over single parameters and whole parameter chains.
class EqParams a where
  eqParams :: a -> a -> Bool

-- | A single parameter defers to the 'Eq' instance of 'Param'.
instance {-# OVERLAPPING #-} EqParams (Param a) where
  eqParams x y = x == y

-- | Two chains are equal when their heads and their tails are both equal.
instance {-# OVERLAPPING #-} (EqParams a, EqParams b) => EqParams (a :> b) where
  eqParams (x1 :> y1) (x2 :> y2) = eqParams x1 x2 && eqParams y1 y2

-- | 'Head' carries nothing to compare.
instance {-# OVERLAPPING #-} EqParams Head where
  eqParams _ _ = True

-- | The unit type always compares equal.
instance EqParams () where
  eqParams _ _ = True



-- | 'Head' is already in chain form and passes through unchanged.
instance {-# OVERLAPPING #-} ToParamArg Head where
  toParamArg = id

-- Common argument types that have 'Eq' and 'Show' become 'ExpectValue'
-- parameters through 'param': they match by equality and are labelled with
-- their 'show' output.
instance {-# OVERLAPPING #-} ToParamArg Int where
  toParamArg = param

instance {-# OVERLAPPING #-} ToParamArg Integer where
  toParamArg = param

instance {-# OVERLAPPING #-} ToParamArg Double where
  toParamArg = param

instance {-# OVERLAPPING #-} ToParamArg Float where
  toParamArg = param

instance {-# OVERLAPPING #-} ToParamArg Bool where
  toParamArg = param

instance {-# OVERLAPPING #-} ToParamArg Char where
  toParamArg = param

instance {-# OVERLAPPING #-} ToParamArg Word where
  toParamArg = param

instance {-# OVERLAPPING #-} ToParamArg T.Text where
  toParamArg = param

-- | Lists (including 'String') of comparable, showable elements.
instance {-# OVERLAPPING #-} (Show a, Eq a) => ToParamArg [a] where
  toParamArg = param

-- | Optional values of comparable, showable elements.
instance {-# OVERLAPPING #-} (Show a, Eq a) => ToParamArg (Maybe a) where
  toParamArg = param

-- | Fallback for any other showable type: stored as a 'ValueWrapper' (no
-- 'Eq' required) and labelled with its 'show' output.
instance {-# OVERLAPPABLE #-} (Normalize a ~ Param a, Typeable a, Show a) => ToParamArg a where
  toParamArg v = ValueWrapper v (show v)

-- | Converts the left operand of '~>' into its parameter-chain form.
class ToParamParam a where
  toParamParam :: a -> Normalize a

-- | Values that are already a 'Param' pass through unchanged.
instance {-# OVERLAPPING #-} ToParamParam (Param a) where
  toParamParam = id

-- | Functions are stored as a 'ValueWrapper' labelled through 'showFunction'.
instance {-# OVERLAPPING #-} (Typeable (a -> b)) => ToParamParam (a -> b) where
  toParamParam fn = ValueWrapper fn $ showFunction fn

-- | Any other value becomes an equality-matching parameter via 'param'.
instance {-# OVERLAPPABLE #-} (Normalize a ~ Param a, Show a, Eq a) => ToParamParam a where
  toParamParam = param

-- | Converts the right operand of '~>' into its parameter-chain form.
class ToParamResult b where
  toParamResult :: b -> Normalize b

-- | An explicit 'Param' is kept as-is.
instance {-# OVERLAPPING #-} ToParamResult (Param a) where
  toParamResult p = p

-- | A nested chain (the tail of a right-associated '~>') is kept as-is.
instance {-# OVERLAPPING #-} ToParamResult (a :> b) where
  toParamResult chain = chain

-- | Any other result value is wrapped through 'WrapResult'.
instance {-# OVERLAPPABLE #-} (Normalize b ~ Param b, WrapResult b) => ToParamResult b where
  toParamResult r = wrapResult r

-- | Builds one link of a parameter chain from a left and a right operand.
class ConsGen a b where
  (~>) :: a -> b -> Normalize a :> Normalize b

-- | The single, fully general instance: the left operand is converted with
-- 'ToParamParam' and the right operand with 'ToParamResult'.
instance (ToParamParam a, ToParamResult b) => ConsGen a b where
  (~>) lhs rhs = toParamParam lhs :> toParamResult rhs

-- | A parameter that accepts every value.
--   Pick the type with a type application when it cannot be inferred,
--   e.g. @any \@String@.
--
--   > f <- mock $ any ~> True
any :: forall a. Param a
any = ExpectCondition alwaysTrue "any"
  where
    -- The label "any" is significant: two conditions match when either is "any".
    alwaysTrue _ = True

{- | Build a conditional parameter with a descriptive label.
    When a mock receives an argument that fails the predicate, an error is
    raised, and the label is included in its message.

    > expect (>5) ">5"
-}
expect :: (a -> Bool) -> String -> Param a
expect predicate label = ExpectCondition predicate label

{- | Build a conditional parameter without a label of its own.
  Error messages show the placeholder "[some condition]" instead.

  > expect_ (>5)
-}
expect_ :: (a -> Bool) -> Param a
expect_ predicate = ExpectCondition predicate placeholder
  where
    placeholder = "[some condition]"

-- | The argument part of a parameter chain, i.e. the chain without its final
-- (return) parameter. A 'Head'-led constant, an 'IO' action, and any type
-- not matched below have no arguments (@()@); 'Cases' uses the arguments of
-- its first type parameter.
type family ArgsOf params where
  ArgsOf (Head :> Param ret) = ()                        -- a constant has no arguments
  ArgsOf (IO x) = ()
  ArgsOf (Param arg :> Param ret) = Param arg
  ArgsOf (Param arg :> rest) = Param arg :> ArgsOf rest
  ArgsOf (Cases x y) = ArgsOf x
  ArgsOf other = ()

-- | Removes the return parameter from a chain, leaving its 'ArgsOf'.
class ProjectionArgs params where
  projArgs :: params -> ArgsOf params

-- | A constant ('Head' followed by its result) has no arguments.
instance {-# OVERLAPPING #-} ProjectionArgs (Head :> Param r) where
  projArgs (_ :> _) = ()

-- | A single argument: keep it and drop the result.
instance {-# OVERLAPPING #-} ProjectionArgs (Param a :> Param r) where
  projArgs (arg :> _) = arg

-- | Longer chains: keep the first argument and recurse into the tail.
instance
  {-# OVERLAPPABLE #-}
  (ProjectionArgs rest, ArgsOf (Param a :> rest) ~ (Param a :> ArgsOf rest)) =>
  ProjectionArgs (Param a :> rest) where
  projArgs (arg :> tailParams) = arg :> projArgs tailParams

-- | The final (return) parameter of a parameter chain.
type family ReturnOf params where
  ReturnOf (Head :> Param ret) = Param ret               -- a constant returns its only parameter
  ReturnOf (Param arg :> Param ret) = Param ret
  ReturnOf (Param arg :> rest) = ReturnOf rest

-- | Extracts the return parameter ('ReturnOf') from a chain.
class ProjectionReturn param where
  projReturn :: param -> ReturnOf param

-- | A constant: the parameter after 'Head' is the result.
instance {-# OVERLAPPING #-} ProjectionReturn (Head :> Param r) where
  projReturn (_ :> result) = result

-- | A single argument followed by the result.
instance {-# OVERLAPPING #-} ProjectionReturn (Param a :> Param r) where
  projReturn (_ :> result) = result

-- | Longer chains: skip the first argument and recurse into the tail.
instance
  {-# OVERLAPPABLE #-}
  (ProjectionReturn rest, ReturnOf (Param a :> rest) ~ ReturnOf rest) =>
  ProjectionReturn (Param a :> rest) where
  projReturn (_ :> tailParams) = projReturn tailParams

-- | The concrete value of a chain's return parameter. Delegates to 'value',
-- so it fails when the return parameter is a condition rather than a value.
returnValue :: (ProjectionReturn params, ReturnOf params ~ Param r) => params -> r
returnValue params = value (projReturn params)

-- | A number derived from a value's heap representation, used only to
-- decorate labels in 'showFunction'.
--
-- NOTE(review): this coerces an arbitrary value to @Ptr ()@ and unwraps it,
-- so the result is whatever word lies where a 'Ptr' constructor would keep
-- its address, not the value's own address. It is neither unique per value
-- nor stable across garbage collection, which is why 'compareFunction' does
-- not use it.
getPtrAddr :: forall a. a -> IntPtr
getPtrAddr x = ptrToIntPtr (castPtr (unsafeCoerce x :: Ptr ()))

-- | Identity comparison for values without a usable 'Eq' instance
-- (typically functions).
--
-- Both arguments are forced to weak head normal form and their
-- 'System.Mem.StableName.StableName's are compared. A 'True' result means
-- both are the same heap object, regardless of the garbage collector moving
-- it. Forcing first avoids the documented case where a thunk and its
-- evaluated result get different stable names. Distinct closures compare
-- unequal even if they compute the same function.
compareFunction :: forall a. a -> a -> Bool
compareFunction x y = unsafePerformIO $ do
  nameX <- makeStableName $! x
  nameY <- makeStableName $! y
  pure (eqStableName nameX nameY)

-- | Label a function with its 'typeOf' string, an at-sign, and the number
-- from 'getPtrAddr'.
showFunction :: forall a. Typeable a => a -> String
showFunction x =
  let typeStr = show (typeOf x)
      ptrAddr = show (getPtrAddr x)
   in typeStr ++ "@" ++ ptrAddr