{-|
Module      :  Language.Haskell.TH.TestUtils
Maintainer  :  Brandon Chinn <brandon@leapyear.io>
Stability   :  experimental
Portability :  portable

This module defines utilities for testing Template Haskell code.
-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

module Language.Haskell.TH.TestUtils
  ( -- * Configuring TestQ
    QState(..)
  , MockedMode(..)
  , QMode(..)
  , ReifyInfo(..)
  , loadNames
  , unmockedState
    -- * Running TestQ
  , runTestQ
  , runTestQErr
  , tryTestQ
  ) where

#if !MIN_VERSION_base(4,13,0)
import Control.Monad.Fail (MonadFail(..))
#endif
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.Except as Except
import qualified Control.Monad.Trans.Reader as Reader
import qualified Control.Monad.Trans.State as State
import Data.Maybe (fromMaybe)
import Language.Haskell.TH (Name, Q, runIO, runQ)
import Language.Haskell.TH.Syntax (Quasi(..), mkNameU)

import Language.Haskell.TH.TestUtils.QMode
import Language.Haskell.TH.TestUtils.QState

-- | Run a 'Q' action under the given 'QState' and return its result.
--
-- This is 'tryTestQ' with the 'Either' eliminated: if the action fails, the
-- failure message is raised with 'error' when the result is inspected.
runTestQ :: forall mode a. IsMockedMode mode => QState mode -> Q a -> TestQResult mode a
runTestQ state action =
  fmapResult @mode @(Either String a) @a unwrap (tryTestQ state action)
  where
    -- A failure becomes a call to 'error'; a success is passed through untouched.
    unwrap (Left msg) = error msg
    unwrap (Right x) = x

-- | Run a 'Q' action that is expected to fail and return its failure message.
--
-- This is 'tryTestQ' with the 'Either' eliminated the other way round from
-- 'runTestQ': if the action unexpectedly succeeds, 'error' is called with a
-- message that includes the shown result.
runTestQErr :: forall mode a. (IsMockedMode mode, Show a) => QState mode -> Q a -> TestQResult mode String
runTestQErr state action =
  fmapResult @mode @(Either String a) @String expectFailure (tryTestQ state action)
  where
    -- The failure message is the result; a success is reported via 'error'.
    expectFailure (Left msg) = msg
    expectFailure (Right a) = error ("Unexpected success: " ++ show a)

-- | Run a 'Q' action under the given 'QState', returning either the failure
-- message ('Left') or the result ('Right').
--
-- The action is interpreted in 'TestQ' via 'runQ', the 'TestQ' layers are then
-- unwrapped down to 'Q', and 'runResult' turns that into the mode-specific
-- result type.
tryTestQ :: forall mode a. IsMockedMode mode => QState mode -> Q a -> TestQResult mode (Either String a)
tryTestQ state action = runResult @mode (Except.runExceptT overExcept)
  where
    -- The action, interpreted by the Quasi instance of TestQ.
    interpreted :: TestQ mode a
    interpreted = runQ action

    -- Peel the layers off from the outside in: reader first, then state.
    overState = Reader.runReaderT (unTestQ interpreted) state
    overExcept = State.evalStateT overState freshInternalState

    -- Every run starts with no stored error and the name counter at zero.
    freshInternalState = InternalState{lastErrorReport = Nothing, newNameCounter = 0}

-- | Bookkeeping threaded through a single 'TestQ' run.
data InternalState = InternalState
  { lastErrorReport :: Maybe String
    -- ^ The message most recently recorded by 'storeLastError', if any.
  , newNameCounter :: Int
    -- ^ The value that 'getAndIncrementNewNameCounter' will return next.
  }

-- | The monad in which 'tryTestQ' interprets 'Q' actions.
--
-- From the outside in, the stack is:
--
-- * a reader carrying the 'QState' the run was configured with,
-- * a state layer carrying the 'InternalState',
-- * an exception layer whose errors are 'String' messages,
-- * 'Q' at the base.
newtype TestQ (mode :: MockedMode) a = TestQ
  { unTestQ :: Reader.ReaderT (QState mode) (State.StateT InternalState (Except.ExceptT String Q)) a
  }
  deriving (Functor, Applicative, Monad)

{- TestQ stack: ReaderT -}

-- | Read the 'QState' held in the reader layer of 'TestQ'.
getState :: TestQ mode (QState mode)
getState = TestQ Reader.ask

-- | Read the 'mode' field of the 'QState' held in the reader layer.
getMode :: TestQ mode (QMode mode)
getMode = TestQ (Reader.asks mode)

-- | Look up a name in the 'reifyInfo' of the current 'QState' and project a
-- value out of the entry found.
--
-- Calls 'error' if the name has no entry.
lookupReifyInfo :: (ReifyInfo -> a) -> Name -> TestQ mode a
lookupReifyInfo f name = do
  infos <- reifyInfo <$> getState
  -- The 'error' sits in action position, so it is raised when the lookup runs
  -- rather than when the result is later forced.
  maybe missing (return . f) (lookup name infos)
  where
    missing = error $ "Cannot reify " ++ show name ++ " (did you mean to add it to reifyInfo?)"

{- TestQ stack: StateT -}

-- | Read the message most recently recorded by 'storeLastError', if any.
getLastError :: TestQ mode (Maybe String)
getLastError = TestQ (lift (lastErrorReport <$> State.get))

-- | Record a message in the state layer, replacing any previously recorded one.
storeLastError :: String -> TestQ mode ()
storeLastError msg = TestQ . lift $ do
  current <- State.get
  State.put current{lastErrorReport = Just msg}

-- | Return the current value of the new-name counter and bump the stored
-- counter by one.
getAndIncrementNewNameCounter :: TestQ mode Int
getAndIncrementNewNameCounter = TestQ . lift $ do
  current <- State.gets newNameCounter
  State.modify (\st -> st{newNameCounter = current + 1})
  return current

{- TestQ stack: ExceptT -}

-- | Abort the current 'TestQ' computation with the given message, using the
-- exception layer of the stack.
throwError :: String -> TestQ mode a
throwError msg = TestQ (lift (lift (Except.throwE msg)))

-- | Run an action and, if it aborts via 'throwError', run the handler on the
-- thrown message instead.
catchError :: TestQ mode a -> (String -> TestQ mode a) -> TestQ mode a
catchError action handler =
  TestQ $
    -- 'Except.catchE' lifted through the state layer and then the reader layer.
    Reader.liftCatch
      (State.liftCatch Except.catchE)
      (unTestQ action)
      (\msg -> unTestQ (handler msg))

{- TestQ stack: Q -}

-- | Run a 'Q' action at the base of the 'TestQ' stack, lifting it through the
-- exception, state and reader layers.
liftQ :: Q a -> TestQ mode a
liftQ q = TestQ (lift (lift (lift q)))

{- Instances -}

-- | IO actions are run by embedding them in 'Q' with 'runIO' and lifting the
-- result with 'liftQ'.
instance MonadIO (TestQ mode) where
  liftIO io = liftQ (runIO io)

-- | Failing aborts the computation via 'throwError'.
instance MonadFail (TestQ mode) where
  -- The implementation of 'fail' for Q sends its message to qReport before
  -- calling 'fail' itself. If a message has been stored that way, throw it;
  -- otherwise fall back to the message given here.
  fail msg = getLastError >>= throwError . fromMaybe msg

-- | Pick the implementation of a 'Quasi' method based on the current mode.
--
-- In 'AllowQ' mode the real 'Q' action is run. In every other mode the mocked
-- behaviour is used: either the replacement 'TestQ' action, or a call to
-- 'error' naming the unsupported method.
use :: Override mode a -> TestQ mode a
use Override{whenAllowed, whenMocked} = do
  currentMode <- getMode
  case currentMode of
    AllowQ -> liftQ whenAllowed
    _ -> mocked
  where
    -- Only inspected when the mode is not 'AllowQ'.
    mocked = case whenMocked of
      DoInstead testQ -> testQ
      Unsupported label -> error $ "Cannot run '" ++ label ++ "' with TestQ"

-- | The two behaviours of a 'Quasi' method that 'use' chooses between.
data Override mode a = Override
  { whenAllowed :: Q a
    -- ^ The real 'Q' action, run when the mode is 'AllowQ'.
  , whenMocked :: WhenMocked mode a
    -- ^ What to do in every other mode.
  }

-- | How a 'Quasi' method behaves when it is not passed through to 'Q'.
data WhenMocked mode a
  = -- | Run this 'TestQ' action in place of the real method.
    DoInstead (TestQ mode a)
  | -- | The method has no mocked implementation; 'use' calls 'error' with a
    -- message containing this label.
    Unsupported String

-- | Interprets the primitive operations of 'Q' inside 'TestQ'.
--
-- Most methods go through 'use': in 'AllowQ' mode they are passed straight to
-- the real 'Q' implementation, and in every other mode they either run a mocked
-- replacement or call 'error' because the operation is unsupported when mocked.
instance Quasi (TestQ mode) where
  {- IO -}

  -- IO is rejected outright in 'MockQ' mode and run in every other mode.
  qRunIO io = getMode >>= \case
    MockQ -> error "IO actions not allowed"
    _ -> liftIO io

  {- Error handling + reporting -}

  -- If the action aborts via 'throwError', run the handler; the thrown message
  -- itself is discarded.
  qRecover handler action = action `catchError` const handler

  -- Warnings are passed through when allowed and dropped when mocked.
  qReport False msg = use Override
    { whenAllowed = qReport False msg
    , whenMocked = DoInstead $ return ()
    }
  -- Errors are stored in every mode so that a following 'fail' can throw them.
  qReport True msg = storeLastError msg

  {- Names -}

  -- When mocked, names are made unique with the internal counter.
  qNewName name = use Override
    { whenAllowed = qNewName name
    , whenMocked = DoInstead $ mkNameU name . fromIntegral <$> getAndIncrementNewNameCounter
    }

  -- When mocked, the name is looked up in 'knownNames'; the Bool argument is
  -- not consulted.
  qLookupName b name = use Override
    { whenAllowed = qLookupName b name
    , whenMocked = DoInstead $ do
        QState{knownNames} <- getState
        return $ lookup name knownNames
    }

  {- ReifyInfo -}

  -- When mocked, the reify methods read the matching field of the 'ReifyInfo'
  -- entry registered for the name (see 'lookupReifyInfo').

  qReify name = use Override
    { whenAllowed = qReify name
    , whenMocked = DoInstead $ lookupReifyInfo reifyInfoInfo name
    }

  qReifyFixity name = use Override
    { whenAllowed = qReifyFixity name
    , whenMocked = DoInstead $ lookupReifyInfo reifyInfoFixity name
    }

  -- An entry without roles is an error when mocked.
  qReifyRoles name = use Override
    { whenAllowed = qReifyRoles name
    , whenMocked = DoInstead $ lookupReifyInfo reifyInfoRoles name >>= \case
        Nothing -> error $ "No roles associated with " ++ show name
        Just roles -> return roles
    }

#if MIN_VERSION_template_haskell(2,16,0)
  qReifyType name = use Override
    { whenAllowed = qReifyType name
    , whenMocked = DoInstead $ lookupReifyInfo reifyInfoType name
    }
#endif

  {- Currently unsupported -}

  -- Everything below is passed through in 'AllowQ' mode and calls 'error' in
  -- every other mode.

  qReifyInstances name types = use Override
    { whenAllowed = qReifyInstances name types
    , whenMocked = Unsupported "qReifyInstances"
    }
  qReifyAnnotations annlookup = use Override
    { whenAllowed = qReifyAnnotations annlookup
    , whenMocked = Unsupported "qReifyAnnotations"
    }
  qReifyModule mod' = use Override
    { whenAllowed = qReifyModule mod'
    , whenMocked = Unsupported "qReifyModule"
    }
  qReifyConStrictness name = use Override
    { whenAllowed = qReifyConStrictness name
    , whenMocked = Unsupported "qReifyConStrictness"
    }
  qLocation = use Override
    { whenAllowed = qLocation
    , whenMocked = Unsupported "qLocation"
    }
  qAddDependentFile fp = use Override
    { whenAllowed = qAddDependentFile fp
    , whenMocked = Unsupported "qAddDependentFile"
    }
  qAddTopDecls decls = use Override
    { whenAllowed = qAddTopDecls decls
    , whenMocked = Unsupported "qAddTopDecls"
    }
  qAddModFinalizer q = use Override
    { whenAllowed = qAddModFinalizer q
    , whenMocked = Unsupported "qAddModFinalizer"
    }
  qGetQ = use Override
    { whenAllowed = qGetQ
    , whenMocked = Unsupported "qGetQ"
    }
  qPutQ a = use Override
    { whenAllowed = qPutQ a
    , whenMocked = Unsupported "qPutQ"
    }
  qIsExtEnabled ext = use Override
    { whenAllowed = qIsExtEnabled ext
    , whenMocked = Unsupported "qIsExtEnabled"
    }
  qExtsEnabled = use Override
    { whenAllowed = qExtsEnabled
    , whenMocked = Unsupported "qExtsEnabled"
    }

#if MIN_VERSION_template_haskell(2,13,0)
  qAddCorePlugin plugin = use Override
    { whenAllowed = qAddCorePlugin plugin
    , whenMocked = Unsupported "qAddCorePlugin"
    }
#endif

#if MIN_VERSION_template_haskell(2,14,0)
  qAddTempFile suffix = use Override
    { whenAllowed = qAddTempFile suffix
    , whenMocked = Unsupported "qAddTempFile"
    }
  qAddForeignFilePath lang fp = use Override
    { whenAllowed = qAddForeignFilePath lang fp
    , whenMocked = Unsupported "qAddForeignFilePath"
    }
#elif MIN_VERSION_template_haskell(2,12,0)
  qAddForeignFile lang fp = use Override
    { whenAllowed = qAddForeignFile lang fp
    , whenMocked = Unsupported "qAddForeignFile"
    }
#endif

  -- The methods below were added to the Quasi class by later template-haskell
  -- releases. Without definitions here they would be undefined class methods,
  -- so a Q action using them would crash even in 'AllowQ' mode.

#if MIN_VERSION_template_haskell(2,18,0)
  qGetDoc loc = use Override
    { whenAllowed = qGetDoc loc
    , whenMocked = Unsupported "qGetDoc"
    }
  qPutDoc loc doc = use Override
    { whenAllowed = qPutDoc loc doc
    , whenMocked = Unsupported "qPutDoc"
    }
#endif

#if MIN_VERSION_template_haskell(2,19,0)
  qGetPackageRoot = use Override
    { whenAllowed = qGetPackageRoot
    , whenMocked = Unsupported "qGetPackageRoot"
    }
#endif