{-# OPTIONS_HADDOCK hide, not-home #-}

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}

module Test.Fixie.Internal.TH where

import qualified Control.Monad.Fail as Fail

import Control.Monad (join, replicateM, when, zipWithM)
import Test.Fixie.Internal (FixieT, Call(..), Function(..), unimplemented, captureCall, getFunction)
import Data.Char (isPunctuation, isSymbol)
import Data.Default.Class (Default(..))
import Data.List (foldl', nub, partition)
import Data.Text (pack)
import GHC.Exts (Constraint)
import Language.Haskell.TH
import Language.Haskell.TH.Syntax

{-|
  Template Haskell entry point that declares a fixture record type whose name
  is given by the first argument. The record holds one field per method of
  every class in the second argument. The splice also emits:

    * a 'Default' instance for the record, and
    * one instance of each listed class for 'FixieT' over the record.

  See the documentation of this package's public Template Haskell module for
  the complete picture. As an example, this splice declares a record called
  @Fixture@ with fields and instances for the classes @Foo@ and @Bar@:

  > mkFixture "Fixture" [ts| Foo, Bar |]

  Classes are written the same way they would be in a @deriving@ clause under
  @GeneralizedNewtypeDeriving@. A multi-parameter typeclass can therefore be
  used as long as every parameter except the last one is supplied. For
  instance, this is accepted:

  > class MultiParam a m where
  >   doSomething :: a -> m ()
  >
  > mkFixture "Fixture" [ts| MultiParam String |]

  Every class is validated with 'assertDerivableConstraint' before any
  declaration is generated. An unsupported class aborts the splice via 'fail'.
-}
mkFixture :: String -> [Type] -> Q [Dec]
mkFixture fixtureNameStr classTypes = do
  -- Validate all classes first, so users get a targeted error message.
  mapM_ assertDerivableConstraint classTypes
  (recordDec, recordFields) <- mkFixtureRecord recordName classTypes
  defaultDec <- mkDefaultInstance recordName recordFields
  classInstanceDecs <- mapM (\classType -> mkInstance classType recordName) classTypes
  pure (recordDec : defaultDec : classInstanceDecs)
  where
    recordName = mkName fixtureNameStr

{-|
  Declares the fixture record itself. The record:

    * is parameterised by one type variable of kind @* -> *@;
    * has a single record constructor named after the type;
    * has as its fields the methods of every given class, converted by
      'methodsToFields'.

  Returns the declaration together with its fields, so that callers can build
  further declarations (such as the 'Default' instance) from the same fields.
-}
mkFixtureRecord :: Name -> [Type] -> Q (Dec, [VarStrictType])
mkFixtureRecord fixtureName classTypes = do
  classInfos <- mapM (reify . unappliedTypeName) classTypes
  methodSigs <- mapM classMethods classInfos
  monadVar <- newName "m"
  fieldGroups <- zipWithM (methodsToFields monadVar) classTypes methodSigs
  let fields = concat fieldGroups
      starToStar = AppT (AppT ArrowT StarT) StarT
      recordCon = RecC fixtureName fields
      recordDec = mkDataD [] fixtureName [KindedTV monadVar starToStar] [recordCon]
  return (recordDec, fields)

{-|
  Builds an @instance Default (Fixture m)@ declaration. Its 'def' is a record
  construction in which every field is set to the expression produced by
  'unimplementedField' for that field's name.
-}
mkDefaultInstance :: Name -> [VarStrictType] -> Q Dec
mkDefaultInstance fixtureName fixtureFields = do
  tyVar <- newName "m"
  let instanceHead = ConT ''Default `AppT` (ConT fixtureName `AppT` VarT tyVar)
      recordExp = RecConE fixtureName [unimplementedField fieldName | (fieldName, _, _) <- fixtureFields]
      defMethod = FunD 'def [Clause [] (NormalB recordExp) []]
  pure (mkInstanceD [] instanceHead [defMethod])

{-|
  Generates @instance Monad m => C (FixieT Fixture e m)@ for the given
  (possibly partially applied) class @C@. Each method of the class is
  implemented by 'mkDictInstanceFunc', which looks up the matching field of
  the fixture record. Fails if the reified name is not a class.
-}
mkInstance :: Type -> Name -> Q Dec
mkInstance classType fixtureName = do
  e <- newName "e"
  m <- newName "m"
  let monadT = VarT m
      fixieT = AppT (AppT (AppT (ConT ''FixieT) (ConT fixtureName)) (VarT e)) monadT

  info <- reify (unappliedTypeName classType)
  sigs <- case info of
    ClassI (ClassD _ _ _ _ classSigs) _ -> pure classSigs
    _ -> fail ("mkInstance: expected a class type, given " ++ show classType)
  methodImpls <- mapM mkDictInstanceFunc sigs

  pure (mkInstanceD [AppT (ConT ''Monad) monadT] (AppT classType fixieT) methodImpls)

{-|
  Ensures that a provided constraint is something this library can actually
  derive an instance for. The class must be applied to every one of its
  parameters except the last, so that the generated instance supplies the
  remaining one. Anything else aborts the splice via 'fail' with a
  descriptive message.

  Kinds shown in the error messages are assembled from the class's type
  variable binders in declaration order. A binder without a kind annotation
  is shown as @*@.
-}
assertDerivableConstraint :: Type -> Q ()
assertDerivableConstraint classType = do
  info <- reify $ unappliedTypeName classType
  classVars <- case info of
    ClassI (ClassD _ _ vars _ _) _ -> return vars
    _ -> fail $ "mkFixture: expected a constraint, given ‘" ++ show (ppr classType) ++ "’"

  let classArgs = typeArgs classType
      numArgs = length classArgs
      numVars = length classVars
      constraintStr = show (ppr (ConT ''Constraint))

      -- the kind of one binder; an unannotated binder is reported as *
      varKind (KindedTV _ k) = k
      varKind (PlainTV _) = StarT

      -- k1 -> k2 -> ... -> Constraint, keeping the binders in declaration
      -- order so the printed kind reads the same way the class is written
      mkClassKind vars = foldr (\k rest -> AppT (AppT ArrowT k) rest) (ConT ''Constraint) (map varKind vars)

  when (numArgs > numVars) $
    fail $ "mkFixture: too many arguments for class\n"
        ++ "      in: " ++ show (ppr classType) ++ "\n"
        ++ "      for class of kind: " ++ show (ppr (mkClassKind classVars))

  when (numArgs == numVars) $
    fail $ "mkFixture: cannot derive instance for fully saturated constraint\n"
        ++ "      in: " ++ show (ppr classType) ++ "\n"
        ++ "      expected: * -> " ++ constraintStr ++ "\n"
        ++ "      given: " ++ constraintStr

  -- more than one parameter left unapplied: report the kind of what remains
  when (numArgs < numVars - 1) $
    fail $ "mkFixture: cannot derive instance for multi-parameter typeclass\n"
        ++ "      in: " ++ show (ppr classType) ++ "\n"
        ++ "      expected: * -> " ++ constraintStr ++ "\n"
        ++ "      given: " ++ show (ppr (mkClassKind $ drop numArgs classVars))

{-|
  Extracts the member declarations of a class from the 'Info' obtained by
  reifying it. These are expected to be 'SigD' method signatures. Any
  non-class 'Info' is rejected with 'fail'.
-}
classMethods :: MonadFail m => Info -> m [Dec]
classMethods info = case info of
  ClassI (ClassD _ _ _ _ sigs) _ -> pure sigs
  _ -> fail ("classMethods: expected a class name, given " ++ show info)

{-|
  Converts every member signature of a single class into a record field with
  'methodToField'. All fields share the same replacement type variable and the
  same class type. Fails as soon as one signature cannot be converted.
-}
methodsToFields :: MonadFail m => Name -> Type -> [Dec] -> m [VarStrictType]
methodsToFields name typ = traverse convert
  where convert = methodToField name typ

{-|
  Turns one class method signature ('SigD') into a field of the fixture
  record. Two things happen:

    1. The method name is prefixed with @_@ (or @~@ for operators) by
       'methodNameToFieldName', so that fields do not clash with the methods.

    2. The class constraint is removed and its type variable is replaced.
       For a class such as

       > class HasFoo x where
       >   foo :: x -> Foo

       the reified type of the method @foo@ is

       > forall x. HasFoo x => x -> Foo

       but the record should contain

       > data Record x = Record { _foo :: x -> Foo }

       So the @HasFoo x@ constraint and the binder for @x@ are dropped, and
       the class's variable becomes the record's own type variable. The
       rewriting is performed by 'replaceClassConstraint'. The 'Name'
       argument is the record's type variable, and the 'Type' argument is the
       (possibly partially applied) class whose constraint is removed.

  Any declaration other than a 'SigD' is rejected with 'fail'.
-}
methodToField :: MonadFail m => Name -> Type -> Dec -> m VarStrictType
methodToField mVar classT dec = case dec of
  SigD methodName methodType ->
    fmap (\fieldType -> (methodNameToFieldName methodName, noStrictness, fieldType))
         (replaceClassConstraint classT mVar methodType)
  _ -> fail "methodToField: internal error; report a bug with the test-fixture package"

{-|
  Derives a record field name from a class method name by prepending a
  character. Operators (names starting with a symbol or punctuation
  character) get @~@, so the field is still an operator. Every other name
  gets @_@. This avoids clashes between the generated fields and the methods.
-}
methodNameToFieldName :: Name -> Name
methodNameToFieldName name = mkName (prefixChar : base)
  where
    base = nameBase name
    -- '_' and the quote characters are punctuation but never start an
    -- operator; ':' only starts constructor operators, not method names.
    isInfixChar c = (c `notElem` "_:\"'") && (isPunctuation c || isSymbol c)
    -- total match instead of 'head'; an empty base is treated as non-infix
    prefixChar = case base of
      c : _ | isInfixChar c -> '~'
      _ -> '_'

{-|
  Implements the class constraint replacement described for 'methodToField'.
  It takes three arguments:

    * the (possibly partially applied) class whose constraint must be removed;
    * the name of the record's type variable;
    * a reified method type of the form @forall vars. preds => body@.

  It then:

    * removes the single predicate whose head is the class,
    * drops the quantified variables mentioned by that predicate, and
    * substitutes those variables, in order of first occurrence, with the
      class's supplied arguments followed by the record's type variable.

  It fails via 'fail' when the type is not a 'ForallT', or when the predicate
  list does not contain exactly one constraint headed by the class.
-}
replaceClassConstraint :: MonadFail m => Type -> Name -> Type -> m Type
replaceClassConstraint classType freeVar (ForallT vars preds typ) =
  case partition belongsToClass preds of
    ([replacedPred], newPreds) ->
      let -- Get the type vars that we need to replace and pair them with their
          -- replacements. assertDerivableConstraint has already checked that
          -- classType is the class applied to all but its last argument, so
          -- the class's arguments plus the record variable line up with the
          -- predicate's variables.
          replacedVars = typeVarNames replacedPred
          replacementTypes = classTypeArgs ++ [VarT freeVar]

          -- the remaining quantified vars after removing the replaced ones
          newVars = filter ((`notElem` replacedVars) . tyVarBndrName) vars

          -- substitute each replaced var with its replacement, in order
          replacedT = foldl' (flip $ uncurry substituteTypeVar) typ (zip replacedVars replacementTypes)
      in return $ ForallT newVars newPreds replacedT
    (matches, _) ->
      fail $ "replaceClassConstraint: expected exactly one constraint for "
          ++ show (ppr unappliedClassType) ++ " in method type, found "
          ++ show (length matches)
  where
    -- split the provided class into the typeclass and its arguments:
    --
    --             MonadFoo Int Bool
    --             ^^^^^^^^ ^^^^^^^^
    --                 |       |
    --  unappliedClassType   classTypeArgs
    unappliedClassType = unappliedType classType
    classTypeArgs = typeArgs classType

    -- Head of a predicate. Unlike 'unappliedType', this never calls 'error'
    -- for predicates whose head is not a type constructor (e.g. @a ~ b@);
    -- such predicates simply do not match the class.
    constraintHead (AppT t _) = constraintHead t
    constraintHead t = t
    belongsToClass p = constraintHead p == unappliedClassType
replaceClassConstraint _ _ _ = fail "replaceClassConstraint: internal error; report a bug with the test-fixture package"

{-|
  Substitutes a type variable with a type everywhere inside another type.
  'replaceClassConstraint' uses this to swap the class's quantified variable
  for the type variable bound by the record declaration.

  Substitution descends into applications, the type part of kind signatures,
  and both the context and the body of nested 'ForallT's. The context matters
  because nested constraints may mention the class variable; leaving them
  untouched would produce an unbound variable. Kinds are not rewritten. The
  traversal does not check for an inner binder shadowing the variable; this
  assumes the variable names come from 'reify' and are therefore unique.
-}
substituteTypeVar :: Name -> Type -> Type -> Type
substituteTypeVar initial replacement = doReplace
  where doReplace (ForallT a b t) = ForallT a (map doReplace b) (doReplace t)
        doReplace (AppT a b) = AppT (doReplace a) (doReplace b)
        doReplace (SigT t k) = SigT (doReplace t) k
        doReplace t@(VarT n)
          | n == initial = replacement
          | otherwise    = t
        doReplace other = other

{-|
  Builds the record-construction entry for one field. The field is set to
  'unimplemented' applied to the field's own name as a string literal. Per
  this module's intent, calling such a field reports that the user never
  supplied an implementation for it.
-}
unimplementedField :: Name -> FieldExp
unimplementedField fieldName =
  let message = LitE (StringL (nameBase fieldName))
  in (fieldName, VarE 'unimplemented `AppE` message)

{-|
  Generates the implementation of one class method inside a generated
  'FixieT' instance for a fixture record. The generated function does four
  things:

    1. Takes as many arguments as 'functionTypeArity' counts for the method's
       type, binding each to a fresh variable.

    2. Retrieves the user-supplied implementation from the fixture record
       using 'getFunction'. The field used is the one named by
       'methodNameToFieldName'.

    3. Records the call with 'captureCall', tagged with the method's name.

    4. Applies the retrieved function to the arguments bound in step 1.

  The argument must be the method's signature ('SigD'). Any other declaration
  is rejected with 'fail'. The result is a single-clause function declaration
  named after the method.
-}
mkDictInstanceFunc :: Dec -> Q Dec
mkDictInstanceFunc (SigD name typ) = do
  let arity = functionTypeArity typ

  -- one fresh pattern variable per counted argument
  argNames <- replicateM arity (newName "x")
  let pats = map VarP argNames

  -- the record field holding the implementation, the method name as a string
  -- literal (used to tag the captured call), and the argument expressions
  let askFunc = VarE (methodNameToFieldName name)
  let nameString = LitE (StringL (nameBase name))
  let vars = map VarE argNames

  -- 'fn in the final nested splice refers to the @fn@ bound earlier in this
  -- same quotation, so the retrieved function is applied to the arguments.
  implE <- [e|do
    fn <- getFunction $(return askFunc)
    let fnString = $(return nameString)
    let call = Call $ Function (pack fnString)
    captureCall call
    $(return $ applyE (VarE 'fn) vars)
   |]

  let funClause = Clause pats (NormalB implE) []
  return $ FunD name [funClause]
mkDictInstanceFunc other = fail $ "mkDictInstanceFunc: expected method signature, given " ++ show other

{-|
  Strips all arguments from an applied type constructor. For example, @T a b@
  becomes @T@. Calls 'error' when the head of the application is not a type
  constructor.
-}
unappliedType :: Type -> Type
unappliedType ty = case ty of
  ConT _ -> ty
  AppT f _ -> unappliedType f
  _ -> error ("expected plain applied type, given " ++ show ty)

{-|
  The 'Name' of the type constructor at the head of an applied type. It is
  derived from 'unappliedType', so it fails in the same way for
  non-constructor heads.
-}
unappliedTypeName :: Type -> Name
unappliedTypeName t = case unappliedType t of
  ConT name -> name
  other -> error ("unappliedTypeName: impossible, unappliedType returned " ++ show other)

{-|
  The counterpart of 'unappliedType'. It lists, left to right, the arguments
  a type is applied to, so @T a b@ yields @[a, b]@. A type that is not an
  application yields @[]@.
-}
typeArgs :: Type -> [Type]
typeArgs = go []
  where
    -- walk down the left spine, consing each argument onto the accumulator
    go acc (AppT f x) = go (x : acc) f
    go acc _ = acc

{-|
  Returns the unique type variables occurring in a type, in order of first
  occurrence. The traversal descends through type applications and through
  the type part of kind signatures (@(m :: k)@). Kinds themselves and all
  other type forms contribute no variables.
-}
typeVarNames :: Type -> [Name]
typeVarNames = nub . collect
  where
    -- collect every occurrence in order; a single 'nub' at the end keeps
    -- first occurrences, matching the previous per-level deduplication
    collect (VarT n) = [n]
    collect (AppT a b) = collect a ++ collect b
    collect (SigT t _) = collect t
    collect _ = []

{-|
  The variable name bound by a type variable binder. The binder may be plain
  or carry a kind annotation.
-}
tyVarBndrName :: TyVarBndr -> Name
tyVarBndrName bndr = case bndr of
  PlainTV n -> n
  KindedTV n _ -> n

{-|
  Counts the arrows on the right-hand spine of a type, skipping over any
  'ForallT' quantification along the way. Non-function types have arity @0@.
  Only syntactic arrows are counted; an arrow hidden behind a type synonym
  is not.

  >>> :set -XTemplateHaskell
  >>> functionTypeArity <$> runQ [t|()|]
  0
  >>> functionTypeArity <$> runQ [t|() -> ()|]
  1
  >>> functionTypeArity <$> runQ [t|() -> () -> ()|]
  2
-}
functionTypeArity :: Type -> Int
functionTypeArity (ForallT _ _ body) = functionTypeArity body
functionTypeArity (AppT (AppT ArrowT _) result) = 1 + functionTypeArity result
functionTypeArity _ = 0

{-|
  Applies a function expression to a list of argument expressions, left to
  right. For example, @applyE f [x, y]@ builds @f x y@. An empty argument
  list returns the function expression unchanged.
-}
applyE :: Exp -> [Exp] -> Exp
applyE fun [] = fun
applyE fun (arg : rest) = applyE (AppE fun arg) rest

{------------------------------------------------------------------------------|
| The following definitions abstract over differences in base and              |
| template-haskell between GHC versions. This allows the same code to work     |
| without writing CPP everywhere and ending up with a small mess.              |
|------------------------------------------------------------------------------}

-- Constraint synonym used for the 'MonadFail' contexts in this module.
--
-- From base 4.13 (GHC 8.8) onwards the Prelude exports 'MonadFail' itself.
-- That class is the one re-exported from "Control.Monad.Fail". A local
-- synonym with the same name would make every unqualified use of 'MonadFail'
-- here ambiguous, so the synonym is only defined for older compilers.
#if __GLASGOW_HASKELL__ < 808
type MonadFail = Fail.MonadFail
#endif

-- | Builds an 'InstanceD' without an overlap pragma from an instance context,
-- an instance head and the instance's member declarations.
mkInstanceD :: Cxt -> Type -> [Dec] -> Dec
mkInstanceD context instanceHead members = InstanceD Nothing context instanceHead members

-- | Builds a 'DataD' declaration with no kind signature and no deriving
-- clauses.
mkDataD :: Cxt -> Name -> [TyVarBndr] -> [Con] -> Dec
mkDataD context typeName tyVars constructors =
  DataD context typeName tyVars Nothing constructors []

-- | The 'Bang' of a record field written with neither an UNPACK pragma nor a
-- strictness annotation.
noStrictness :: Bang
noStrictness = Bang unpackedness strictness
  where
    unpackedness = NoSourceUnpackedness
    strictness = NoSourceStrictness