{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedRecordDot #-}

module Test.MockCat.TH.FunctionBuilder
  ( createMockBuilderVerifyParams
  , createMockBuilderFnType
  , MockFnContext(..)
  , MockFnBuilder(..)
  , buildMockFnContext
  , buildMockFnDeclarations
  , determineMockFnBuilder
  , createNoInlinePragma
  , doCreateMockFnDecs
  , doCreateConstantMockFnDecs
  , doCreateEmptyVerifyParamMockFnDecs
  , createMockBody
  , createTypeablePreds
  , partialAdditionalPredicates
  , createFnName
  , findParam
  , typeToNames
  , safeIndex
  , generateInstanceMockFnBody
  , generateInstanceRealFnBody
  , generateStubFn
  )
where

import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (lift)
import Language.Haskell.TH
  ( Dec (..),
    Exp (..),
    Name,
    Pred,
    Q,
    Quote,
    Type (..),
    TyVarBndr(..),
    Inline (NoInline),
    RuleMatch (FunLike),
    Phases (AllPhases),
    mkName,
    newName
  )
import Language.Haskell.TH.Lib
import Language.Haskell.TH.Syntax (nameBase, Specificity (SpecifiedSpec))
import Test.MockCat.Mock ( MockBuilder )
import qualified Test.MockCat.Internal.MockRegistry as Registry
import Test.MockCat.Internal.Builder (buildMock)
import Test.MockCat.Internal.Types (BuiltMock(..))
import Test.MockCat.Cons (Head(..), (:>)(..))
import Test.MockCat.MockT
  ( MockT (..),
    Definition (..),
    getDefinitions,
    addDefinition
  )
import Test.MockCat.TH.TypeUtils
  ( isNotConstantFunctionType,
    needsTypeable,
    collectTypeVars,
    collectTypeableTargets
  )
import Test.MockCat.TH.ContextBuilder
  ( MockType (..)
  )
import Test.MockCat.TH.ClassAnalysis
  ( VarAppliedType (..),
    updateType
  )
import Test.MockCat.Verify (ResolvableParamsOf)
import Data.Dynamic (Dynamic, toDyn)
import Data.Proxy (Proxy(..))
import Data.List (find, nubBy)
import Data.Typeable (Typeable)
import Language.Haskell.TH.Ppr (pprint)
import Unsafe.Coerce (unsafeCoerce)
import GHC.TypeLits (KnownSymbol, symbolVal)
 

import Test.MockCat.Param (Param, param)
import Test.MockCat.TH.Types (MockOptions(..))

-- | Derive the verification-parameter type for a mocked function type.
--
-- Walking the (possibly @forall@-quantified) arrow chain, each argument type
-- @arg@ becomes @'Param' arg@ and consecutive arguments are joined with
-- '(:>)'. The chain stops at the argument whose result is an application
-- headed by a type variable (such as @m r@). A non-arrow type yields @()@,
-- and a chain that never reaches such a result is terminated with @()@
-- (for example @a -> Int@ becomes @Param a :> ()@).
createMockBuilderVerifyParams :: Type -> Type
createMockBuilderVerifyParams = \case
  ForallT _ _ body -> createMockBuilderVerifyParams body
  AppT (AppT ArrowT argTy) resTy
    | AppT (VarT _) _ <- resTy -> paramOf argTy
    | otherwise ->
        AppT
          (AppT (ConT ''(:>)) (paramOf argTy))
          (createMockBuilderVerifyParams resTy)
  _ -> TupleT 0
  where
    paramOf = AppT (ConT ''Param)

-- | Remove the monad from the result of a method type.
--
-- Follows the right spine of type applications (dropping any @forall@ met on
-- the way). On reaching @v t@ where @v@ is the given monad variable, the
-- application is replaced by @t@. An application headed by a different type
-- variable, or any other leaf, is left unchanged.
createMockBuilderFnType :: Name -> Type -> Type
createMockBuilderFnType monadVarName = go
  where
    go ty = case ty of
      AppT (VarT v) inner
        | v == monadVarName -> inner
        | otherwise -> ty
      AppT f x -> AppT f (go x)
      ForallT _ _ body -> go body
      other -> other

-- | Equality predicate tying a function type to its verification parameters.
--
-- Yields @[ResolvableParamsOf funType ~ verifyParams]@ when 'collectTypeVars'
-- finds type variables in @funType@, and no predicate otherwise.
partialAdditionalPredicates :: Type -> Type -> [Pred]
partialAdditionalPredicates funType verifyParams
  | null (collectTypeVars funType) = []
  | otherwise = [AppT (AppT EqualityT resolvable) verifyParams]
  where
    resolvable = AppT (ConT ''ResolvableParamsOf) funType

-- | Build @Typeable t@ predicates for the types that need them.
--
-- Candidates are gathered from every input type with
-- 'collectTypeableTargets'. Duplicates are dropped, keeping the first
-- occurrence, where two candidates are considered equal when their
-- pretty-printed forms match. Only candidates accepted by 'needsTypeable'
-- produce a predicate.
createTypeablePreds :: [Type] -> [Pred]
createTypeablePreds =
  map (AppT (ConT ''Typeable))
    . filter needsTypeable
    . nubBy samePrinted
    . concatMap collectTypeableTargets
  where
    samePrinted x y = pprint x == pprint y


-- | Everything needed to generate the mock builder for one class method.
--
-- Values are normally produced by 'buildMockFnContext'.
data MockFnContext = MockFnContext
  { mockType :: MockType
    -- ^ Whether a 'Partial' or a 'Total' mock is being generated.
  , monadVarName :: Name
    -- ^ Type variable used as the monad (constrained with 'MonadIO' and
    -- passed to 'MockT' in the generated signature).
  , mockOptions :: MockOptions
    -- ^ Naming and return-style options.
  , originalType :: Type
    -- ^ The method's type as given, before any substitution.
  , fnNameStr :: String
    -- ^ Name of the generated mock builder (see 'createFnName').
  , mockFnName :: Name
    -- ^ 'fnNameStr' as a Template Haskell 'Name'.
  , paramsName :: Name
    -- ^ Name of the type variable for the builder's parameter argument.
  , updatedType :: Type
    -- ^ 'originalType' after 'updateType' has applied the class-variable
    -- substitutions.
  , fnType :: Type
    -- ^ 'updatedType', with the monad removed from the result by
    -- 'createMockBuilderFnType' when 'implicitMonadicReturn' is set.
  }

-- | Which declaration generator 'buildMockFnDeclarations' uses for a method
-- (chosen by 'determineMockFnBuilder').
data MockFnBuilder
  = -- | Chosen when 'isNotConstantFunctionType' holds for the original type;
    -- handled by 'doCreateMockFnDecs'.
    VariadicBuilder
  | -- | Otherwise, when 'implicitMonadicReturn' is set; handled by
    -- 'doCreateConstantMockFnDecs'.
    ConstantImplicitBuilder
  | -- | Otherwise; handled by 'doCreateEmptyVerifyParamMockFnDecs'.
    ConstantExplicitBuilder

-- | Collect the names and types needed to generate a mock builder for one
-- class method.
--
-- The builder name comes from 'createFnName'. The parameter type variable is
-- always named @p@. 'fnType' has its monad stripped only when
-- 'implicitMonadicReturn' is enabled.
buildMockFnContext
  :: MockType
  -- ^ partial or total mock
  -> Name
  -- ^ monad type variable
  -> [VarAppliedType]
  -- ^ substitutions passed to 'updateType'
  -> MockOptions
  -- ^ naming and return-style options
  -> Name
  -- ^ name of the original method
  -> Type
  -- ^ type of the original method
  -> MockFnContext
buildMockFnContext mockType monadVarName varAppliedTypes mockOptions sigFnName ty =
  MockFnContext
    { mockType = mockType
    , monadVarName = monadVarName
    , mockOptions = mockOptions
    , originalType = ty
    , fnNameStr = builderName
    , mockFnName = mkName builderName
    , paramsName = mkName "p"
    , updatedType = substituted
    , fnType = builderFnType
    }
  where
    builderName = createFnName sigFnName mockOptions
    substituted = updateType ty varAppliedTypes
    builderFnType
      | mockOptions.implicitMonadicReturn = createMockBuilderFnType monadVarName substituted
      | otherwise = substituted

-- | Generate the signature and definition of the mock builder described by
-- the context, using the generator selected by 'determineMockFnBuilder'.
buildMockFnDeclarations :: MockFnContext -> Q [Dec]
buildMockFnDeclarations ctx = case determineMockFnBuilder ctx of
  VariadicBuilder ->
    doCreateMockFnDecs
      ctx.mockType
      ctx.fnNameStr
      ctx.mockFnName
      ctx.paramsName
      ctx.fnType
      ctx.monadVarName
      ctx.updatedType
  ConstantImplicitBuilder ->
    doCreateConstantMockFnDecs
      ctx.mockType
      ctx.fnNameStr
      ctx.mockFnName
      ctx.fnType
      ctx.monadVarName
  ConstantExplicitBuilder ->
    doCreateEmptyVerifyParamMockFnDecs
      ctx.fnNameStr
      ctx.mockFnName
      ctx.paramsName
      ctx.fnType
      ctx.monadVarName
      ctx.updatedType

-- | Choose the declaration generator for a method.
--
-- Methods whose original type satisfies 'isNotConstantFunctionType' use the
-- variadic builder. Otherwise the choice depends on 'implicitMonadicReturn'.
determineMockFnBuilder :: MockFnContext -> MockFnBuilder
determineMockFnBuilder ctx =
  if isNotConstantFunctionType ctx.originalType
    then VariadicBuilder
    else
      if ctx.mockOptions.implicitMonadicReturn
        then ConstantImplicitBuilder
        else ConstantExplicitBuilder

-- | Build a @NOINLINE@ pragma declaration for the given name, using
-- 'FunLike' matching and 'AllPhases'.
createNoInlinePragma :: Name -> Q Dec
createNoInlinePragma target =
  pragInlD target NoInline FunLike AllPhases

-- | Generate the signature and definition of a mock builder for a method
-- classified as a function.
--
-- The generated code has the shape
--
-- > fnName :: (ctx) => params -> MockT m fnType
-- > fnName p = <expression from 'createMockBody'>
--
-- where @verifyParams@ is 'createMockBuilderVerifyParams' of @updatedType@
-- and @ctx@ consists of:
--
--   * @MockBuilder params fnType verifyParams@, unless @verifyParams@ is @()@;
--   * @MonadIO m@;
--   * @ResolvableParamsOf fnType ~ verifyParams@ when @fnType@ has type
--     variables (see 'partialAdditionalPredicates');
--   * the @Typeable@ predicates from 'createTypeablePreds'.
--
-- Partial and total mocks receive exactly the same context, so the
-- 'MockType' argument is not inspected; it is kept for interface
-- compatibility.
doCreateMockFnDecs :: (Quote m) => MockType -> String -> Name -> Name -> Type -> Name -> Type -> m [Dec]
doCreateMockFnDecs _mockType funNameStr mockFunName params funType monadVarName updatedType = do
  newFunSig <- sigD mockFunName (pure (ForallT [] ctx resultType))
  -- `p` is not bound in this quote; it resolves to the parameter introduced
  -- by `varP (mkName "p")` in the clause below.
  mockBody <- createMockBody funNameStr [|p|] funType
  newFun <- funD mockFunName [clause [varP (mkName "p")] (normalB (pure mockBody)) []]
  pure [newFunSig, newFun]
  where
    verifyParams = createMockBuilderVerifyParams updatedType
    mockBuilderPred =
      AppT (AppT (AppT (ConT ''MockBuilder) (VarT params)) funType) verifyParams
    ctx =
      [mockBuilderPred | verifyParams /= TupleT 0]
        ++ [AppT (ConT ''MonadIO) (VarT monadVarName)]
        ++ partialAdditionalPredicates funType verifyParams
        ++ createTypeablePreds [funType, verifyParams]
    resultType =
      AppT
        (AppT ArrowT (VarT params))
        (AppT (AppT (ConT ''MockT) (VarT monadVarName)) funType)

-- | Generate the signature and definition of a mock builder for a method
-- whose original type is not classified as a function (used by
-- 'buildMockFnDeclarations' for 'ConstantImplicitBuilder').
--
-- Arguments: the mock type, the builder name as a string (passed to
-- 'createMockBody'), the builder name, the method type @ty@ (ignored by the
-- 'Partial' equation) and the monad type variable @m@.
--
-- * 'Partial': the stubbed value gets a fresh type variable @r@ and the
--   signature is
--   @forall r m. (ResolvableParamsOf r ~ (), MonadIO m, Typeable r, Show r, Eq r) => r -> MockT m r@.
-- * 'Total', when @ty@ is a type constructor applied to @m@ itself: the same
--   shape, with a fresh type variable @a@ instead of @r@.
-- * 'Total', otherwise: @forall m. (MonadIO m, ...) => ty -> MockT m ty@. The
--   context holds the @Typeable@ predicates for @ty@, plus
--   @MockBuilder (Head :> Param ty) ty ()@ when
--   'createMockBuilderVerifyParams' of @ty@ is not @()@.
--
-- In every case the body is 'createMockBody' applied to @Head :> param p@,
-- where @p@ is the builder's single argument.
--
-- NOTE(review): in the last case the explicit @forall@ binds only @m@. If @ty@
-- mentions other type variables, GHC's forall-or-nothing rule may reject the
-- spliced signature — confirm this cannot happen for the types routed here.
doCreateConstantMockFnDecs :: (Quote m) => MockType -> String -> Name -> Type -> Name -> m [Dec]
doCreateConstantMockFnDecs Partial funNameStr mockFunName _ monadVarName = do
  -- Fresh type variable for the stubbed value; the method type is not used.
  stubVar <- newName "r"
  let ctx =
        [ AppT
            (AppT EqualityT (AppT (ConT ''ResolvableParamsOf) (VarT stubVar)))
            (TupleT 0)
        , AppT (ConT ''MonadIO) (VarT monadVarName)
        , AppT (ConT ''Typeable) (VarT stubVar)
        , AppT (ConT ''Show) (VarT stubVar)
        , AppT (ConT ''Eq) (VarT stubVar)
        ]
      -- r -> MockT m r
      resultType =
        AppT
          (AppT ArrowT (VarT stubVar))
          (AppT (AppT (ConT ''MockT) (VarT monadVarName)) (VarT stubVar))
  newFunSig <-
    sigD
      mockFunName
      ( pure
          (ForallT
              [ PlainTV stubVar SpecifiedSpec
              , PlainTV monadVarName SpecifiedSpec
              ]
              ctx
              resultType
          )
      )
  -- `p` is not bound in this quote; it resolves to the parameter introduced
  -- by `varP $ mkName "p"` in the clause below.
  headParam <- [|Head :> param p|]
  mockBody <- createMockBody funNameStr (pure headParam) (VarT stubVar)
  newFun <- funD mockFunName [clause [varP $ mkName "p"] (normalB (pure mockBody)) []]
  pure $ newFunSig : [newFun]
doCreateConstantMockFnDecs Total funNameStr mockFunName ty monadVarName = do
  -- Choose the signature together with the value type handed to createMockBody.
  (newFunSig, funTypeForBody) <- case ty of
    -- `ty` is a type constructor applied to the monad variable itself: the
    -- value gets a fresh type variable `a`, constrained like the Partial case.
    AppT (ConT _) (VarT mv) | mv == monadVarName -> do
      a <- newName "a"
      let ctx =
            [ AppT (ConT ''MonadIO) (VarT monadVarName)
            , AppT (AppT EqualityT (AppT (ConT ''ResolvableParamsOf) (VarT a))) (TupleT 0)
            , AppT (ConT ''Typeable) (VarT a)
            , AppT (ConT ''Show) (VarT a)
            , AppT (ConT ''Eq) (VarT a)
            ]
          -- a -> MockT m a
          resultType =
            AppT
              (AppT ArrowT (VarT a))
              (AppT (AppT (ConT ''MockT) (VarT monadVarName)) (VarT a))
      sig <- sigD
        mockFunName
        ( pure
            (ForallT
                [PlainTV a SpecifiedSpec, PlainTV monadVarName SpecifiedSpec]
                ctx
                resultType
            )
        )
      pure (sig, VarT a)
    -- Any other type: the builder is monomorphic in the value, ty -> MockT m ty.
    _ -> do
      -- Head :> Param ty, the type of the `Head :> param p` expression below.
      let headParamType = AppT (AppT (ConT ''(:>)) (ConT ''Head)) (AppT (ConT ''Param) ty)
          verifyParams' = createMockBuilderVerifyParams ty
          -- NOTE(review): the verify-params slot is always () even though the
          -- predicate is only added when verifyParams' is not () — confirm
          -- this is intended.
          mockBuilderPred' = AppT (AppT (AppT (ConT ''MockBuilder) headParamType) ty) (TupleT 0)
          ctx =
            [ AppT (ConT ''MonadIO) (VarT monadVarName)
            ]
            ++ ([mockBuilderPred' | verifyParams' /= TupleT 0])
            ++ createTypeablePreds [ty]
          resultType =
            AppT
              (AppT ArrowT ty)
              (AppT (AppT (ConT ''MockT) (VarT monadVarName)) ty)
      sig <- sigD mockFunName (pure (ForallT [PlainTV monadVarName SpecifiedSpec] ctx resultType))
      pure (sig, ty)
  -- As above, `p` resolves to the parameter bound by the clause below.
  headParam <- [|Head :> param p|]
  mockBody <- createMockBody funNameStr (pure headParam) funTypeForBody
  newFun <- funD mockFunName [clause [varP $ mkName "p"] (normalB (pure mockBody)) []]
  pure $ newFunSig : [newFun]

-- | Generate the signature and definition of a mock builder for a method
-- whose original type is not classified as a function, when
-- 'implicitMonadicReturn' is off (used for 'ConstantExplicitBuilder').
--
-- The generated code has the shape
--
-- > fnName :: (ctx) => params -> MockT m fnType
-- > fnName p = <expression from 'createMockBody'>
--
-- where @verifyParams@ is 'createMockBuilderVerifyParams' of @updatedType@
-- and @ctx@ consists of:
--
--   * @MockBuilder params fnType verifyParams@, always, even when
--     @verifyParams@ is @()@ (unlike 'doCreateMockFnDecs');
--   * @MonadIO m@;
--   * @ResolvableParamsOf fnType ~ verifyParams@ when @fnType@ has type
--     variables (see 'partialAdditionalPredicates');
--   * the @Typeable@ predicates from 'createTypeablePreds'.
doCreateEmptyVerifyParamMockFnDecs :: (Quote m) => String -> Name -> Name -> Type -> Name -> Type -> m [Dec]
doCreateEmptyVerifyParamMockFnDecs funNameStr mockFunName params funType monadVarName updatedType = do
  newFunSig <- sigD mockFunName (pure (ForallT [] ctx resultType))
  -- `p` is not bound in this quote; it resolves to the parameter introduced
  -- by `varP (mkName "p")` in the clause below.
  mockBody <- createMockBody funNameStr [|p|] funType
  newFun <- funD mockFunName [clause [varP (mkName "p")] (normalB (pure mockBody)) []]
  pure [newFunSig, newFun]
  where
    verifyParams = createMockBuilderVerifyParams updatedType
    mockBuilderPred =
      AppT (AppT (AppT (ConT ''MockBuilder) (VarT params)) funType) verifyParams
    ctx =
      [mockBuilderPred, AppT (ConT ''MonadIO) (VarT monadVarName)]
        ++ partialAdditionalPredicates funType verifyParams
        ++ createTypeablePreds [funType, verifyParams]
    resultType =
      AppT
        (AppT ArrowT (VarT params))
        (AppT (AppT (ConT ''MockT) (VarT monadVarName)) funType)

-- | Build the expression used as the body of a generated mock builder.
--
-- The generated 'MockT' action proceeds in four steps:
--
--   1. build the mock from the given parameter expression with 'buildMock';
--   2. register the mock function and its recorder with 'Registry.register'
--      under @funNameStr@;
--   3. record a 'Definition' tagged with the type-level symbol @funNameStr@;
--   4. return the instance handed back by the registry.
--
-- The 'Type' argument is not used.
createMockBody :: (Quote m) => String -> m Exp -> Type -> m Exp
createMockBody funNameStr paramsExp _funType = do
  builderParams <- paramsExp
  let nameLit = litE (stringL funNameStr)
      nameSym = litT (strTyLit funNameStr)
  [|
    MockT $ do
      -- Build the mock directly so its recorder is available for registration.
      BuiltMock { builtMockFn = mockInstance, builtMockRecorder = verifier } <- liftIO $ buildMock (Just $(nameLit)) $(pure builderParams)
      -- Register it and continue with the canonical wrapper the registry
      -- returns (kept for async safety).
      canonicalInstance <- liftIO $ Registry.register (Just $(nameLit)) verifier mockInstance
      addDefinition
        ( Definition
            (Proxy :: Proxy $(nameSym))
            canonicalInstance
            NoVerification
        )
      pure canonicalInstance
    |]

-- | Name of the generated mock builder: the configured prefix, the
-- unqualified base name of the original function, then the configured
-- suffix.
createFnName :: Name -> MockOptions -> String
createFnName funName opts = opts.prefix <> nameBase funName <> opts.suffix

-- | Find the first 'Definition' whose type-level name equals the proxy's
-- symbol, and return its mock function wrapped in a 'Dynamic'.
findParam :: KnownSymbol sym => Proxy sym -> [Definition] -> Maybe Dynamic
findParam pa definitions =
  case find (\(Definition s _ _) -> symbolVal s == wanted) definitions of
    Nothing -> Nothing
    Just (Definition _ mockFunction _) -> Just (toDyn mockFunction)
  where
    wanted = symbolVal pa

-- | One fresh-name action (base name @a@) per argument of a function type.
--
-- Arguments are counted along the arrow chain, looking through @forall@. A
-- non-function type yields an empty list.
typeToNames :: Type -> [Q Name]
typeToNames ty = replicate (arity ty) (newName "a")
  where
    arity :: Type -> Int
    arity (AppT (AppT ArrowT _) rest) = 1 + arity rest
    arity (ForallT _ _ body) = arity body
    arity _ = 0

-- | Total list indexing: @Just@ the element at the given zero-based position,
-- or 'Nothing' when the list is too short or the index is negative.
safeIndex :: [a] -> Int -> Maybe a
safeIndex xs n = case xs of
  [] -> Nothing
  y : ys
    | n == 0 -> Just y
    | n < 0 -> Nothing
    | otherwise -> safeIndex ys (n - 1)


-- | Generate the body of a type-class method in a mock instance.
--
-- The generated code looks up the 'Definition' registered under
-- @fnNameStr@ and applies its mock function to @args@. The result is bound
-- strictly to @r@ and returned with @pure@ when 'implicitMonadicReturn' is
-- set, or with @lift@ otherwise. If no definition is registered the
-- generated code calls 'error'.
generateInstanceMockFnBody :: String -> [Q Exp] -> Name -> MockOptions -> Q Exp
generateInstanceMockFnBody fnNameStr args r opts =
  generateInstanceFnBodyWith
    fnNameStr
    args
    r
    opts
    [|error $ "no answer found stub function `" ++ fnNameStr ++ "`."|]

-- | Like 'generateInstanceMockFnBody', but when no definition is registered
-- the generated code falls back to calling the real function @fnName@ on
-- @args@, lifted with @lift@.
generateInstanceRealFnBody :: Name -> String -> [Q Exp] -> Name -> MockOptions -> Q Exp
generateInstanceRealFnBody fnName fnNameStr args r opts =
  generateInstanceFnBodyWith
    fnNameStr
    args
    r
    opts
    [|lift $ $(foldl appE (varE fnName) args)|]

-- | Shared implementation of the instance-method bodies above.
--
-- Looks up the definition for @fnNameStr@, applies its stored function to
-- @args@ and returns the result. @fallback@ is spliced in for the case where
-- no definition exists.
--
-- NOTE(review): the stored function is recovered with 'unsafeCoerce'. This
-- assumes the definition registered under @fnNameStr@ has exactly the type
-- this method expects; nothing here checks it.
generateInstanceFnBodyWith :: String -> [Q Exp] -> Name -> MockOptions -> Q Exp -> Q Exp
generateInstanceFnBodyWith fnNameStr args r opts fallback = do
  returnExp <-
    if opts.implicitMonadicReturn
      then [|pure $(varE r)|]
      else [|lift $(varE r)|]
  [|
    MockT $ do
      defs <- getDefinitions
      let findDef = find (\(Definition s _ _) -> symbolVal s == $(litE (stringL fnNameStr))) defs
      case findDef of
        Just (Definition _ mf _) -> do
          let mock = unsafeCoerce mf
          let $(bangP $ varP r) = $(generateStubFn args [|mock|])
          $(pure returnExp)
        Nothing -> $(fallback)
    |]

-- | Apply the stub expression to the argument expressions from left to
-- right, producing @mock a1 a2 ... an@. With no arguments the stub
-- expression itself is returned.
generateStubFn :: [Q Exp] -> Q Exp -> Q Exp
generateStubFn args mock = foldl appE mock args


