-- SPDX-FileCopyrightText: 2021 Oxhead Alpha
-- SPDX-License-Identifier: LicenseRef-MIT-OA

module Morley.Util.TH
  ( deriveGADTNFData
  , lookupTypeNameOrFail
  , isTypeAlias
  , addTypeVariables
  , tupT
  ) where

import Control.Monad.Fix (mfix)
import Language.Haskell.TH
import Prelude hiding (Type)

{-# ANN module ("HLint: ignore Language.Haskell.TH should be imported post-qualified or with an explicit import list" :: Text) #-}

-- | Generates an 'NFData' instance for a GADT.
--
-- For every data constructor the generated @rnf@ gets one clause that applies
-- @rnf@ to each constructor field and chains the results with 'seq', roughly
--
-- > rnf (Con a1 a2) = (() `seq` rnf a1) `seq` rnf a2
--
-- Constructors must be written in GADT syntax (plain or record GADT
-- constructors, possibly under @forall@ / contexts); any other constructor
-- form aborts the splice with an error. A data type without constructors
-- gets @rnf x = x `seq` ()@.
--
-- On superclass constraints for type arguments:
-- we use heuristics to guess for which type arguments
-- we need to add @NFData@ instance. Currently a constraint is added only for
-- representational type arguments whose kind is either not annotated or is
-- annotated as @*@; nominal and phantom arguments get no constraint.
-- If this behaves not as you want, probably it's just worth
-- starting passing the necessary constraints to this function manually.
deriveGADTNFData :: Name -> Q [Dec]
deriveGADTNFData name = do
  seqQ <- [| seq |]
  unit <- [| () |]
  -- Explicit match instead of a failable pattern bind, so that a wrong
  -- argument yields a readable error rather than a generic
  -- "pattern match failure in do expression".
  (dataName, vars, cons) <- reify name >>= \case
    TyConI (DataD _ dName dVars _ dCons _) -> pure (dName, dVars, dCons)
    _ -> fail $ "deriveGADTNFData: expected a data type, but got " <> pprint name
  tyArgRoles <- reifyRoles name
  let
    nfDataC :: Type
    nfDataC = ConT $ mkName "NFData"

    getNameFromVar :: TyVarBndr flag -> Name
    getNameFromVar (PlainTV n _) = n
    getNameFromVar (KindedTV n _ _) = n

    -- Unfolds multiple constructors of form "A, B, C :: X -> Y -> Stuff"
    -- into a list of tuples of constructor names and their field types.
    -- This runs in 'Q' (not in the list monad) so that the error below
    -- really aborts the splice: in the list monad 'fail' is just '[]' and
    -- would silently drop the constructor.
    unfoldConstructor :: Con -> Q [(Name, [BangType])]
    unfoldConstructor = \case
      GadtC cs bangs _ -> pure $ map (, bangs) cs
      -- Record GADT constructors: drop the field names, keep bang + type.
      RecGadtC cs fields _ -> pure $ map (, map (\(_, b, t) -> (b, t)) fields) cs
      ForallC _ _ c -> unfoldConstructor c
      _ -> fail "Non GADT constructors are not supported."

    -- Constructs a clause
    -- "rnf (ConName a1 a2 ...) = ((() `seq` rnf a1) `seq` rnf a2) `seq` ..."
    makeClauses :: (Name, [BangType]) -> Q Clause
    makeClauses (conName, bangs) = do
      varNames <- traverse (\_ -> newName "a") bangs
      let rnfExp = AppE (VarE 'rnf) . VarE
          infixSeq e1 e2 = InfixE (Just e1) seqQ (Just e2)
      pure $ Clause
        [ConP conName $ map VarP varNames]
        (NormalB $ foldl' infixSeq unit (map rnfExp varNames))
        []

    -- Clause "rnf x = x `seq` ()", used when the type has no constructors:
    -- GHC rejects a function binding with no equations.
    emptyClause :: Q Clause
    emptyClause = do
      x <- newName "x"
      pure $ Clause [VarP x] (NormalB $ InfixE (Just $ VarE x) seqQ (Just unit)) []

    -- Instance head: NFData (DataName v1 v2 ...)
    nfDataT :: Type
    nfDataT =
      AppT nfDataC . foldl' AppT (ConT dataName) $
        map (VarT . getNameFromVar) vars

    -- Instance context: one "NFData v" per type argument selected below.
    nfDataConstr :: [Type]
    nfDataConstr = do
      (var, role) <- zip vars tyArgRoles

      -- Heuristic: only representational arguments get a constraint.
      -- Phantom arguments never hold data; nominal ones are skipped as well
      -- (presumably because in a GADT they are typically type-level indices
      -- rather than stored values — NOTE(review): confirm this is intended).
      case role of
        NominalR -> mzero
        RepresentationalR -> pass
        PhantomR -> mzero
        InferR -> error "unexpected InferR returned by reifyRole"

      -- Only types of 'Type' kind may require 'NFData' constraint
      varTy <- case var of
        PlainTV v _ -> pure v
        KindedTV v _ k -> do
          guard (k == StarT)
          pure v

      return $ nfDataC `AppT` VarT varTy

    makeInstance :: [Clause] -> Dec
    makeInstance clauses =
      InstanceD Nothing nfDataConstr nfDataT [FunD (mkName "rnf") clauses]

  namedFields <- mconcat <$> traverse unfoldConstructor cons
  clauses <- case namedFields of
    [] -> (: []) <$> emptyClause
    _ -> traverse makeClauses namedFields
  return [makeInstance clauses]

-- | Resolve a type name with 'lookupTypeName', failing the splice with a
-- message naming the requested type when the lookup returns 'Nothing'.
lookupTypeNameOrFail :: String -> Q Name
lookupTypeNameOrFail typeStr =
  maybe (fail failureMsg) pure =<< lookupTypeName typeStr
  where
    failureMsg = "Failed type name lookup for: '" <> typeStr <> "'."

-- | Check whether the given name is a type synonym (declared with @type@).
-- Any other reified entity yields 'False'.
isTypeAlias :: Name -> Q Bool
isTypeAlias typeName = do
  info <- reify typeName
  pure $ case info of
    TyConI TySynD{} -> True
    _ -> False

-- | Accepts a type constructor and fills it with variables until
-- getting a type of kind @*@.
--
-- Only @data@ and @newtype@ declarations are accepted; anything else fails
-- the splice. The variables applied are exactly the binders from the
-- declaration head (e.g. @data Foo a b@ gives @Foo a b@). A type whose
-- arity comes from a return-kind signature (e.g. @data T :: Type -> Type@)
-- is therefore not saturated.
addTypeVariables :: Name -> TypeQ
addTypeVariables tyCtor = do
  tyVarBindrs <- reify tyCtor >>= \case
    TyConI (DataD _ _ bndrs _ _ _) -> pure bndrs
    TyConI (NewtypeD _ _ bndrs _ _ _) -> pure bndrs
    _ -> fail "Expected a plain datatype"
  let vars = tyVarBindrs <&> \case
        PlainTV vName _ -> vName
        KindedTV vName _ _ -> vName
  -- Strict left fold: TyCtor `AppT` v1 `AppT` v2 ...
  return $ foldl' AppT (ConT tyCtor) (map VarT vars)

-- | Given a list of types, produce the type of a tuple of
-- those types. This is analogous to 'tupE' and 'tupP'.
--
-- @
-- tupT [[t|Int|], [t|Char|], [t|Bool|]] = [t| (Int, Char, Bool) |]
-- @
tupT :: [Q Type] -> Q Type
tupT ts = do
  -- We build the type with a thunk inside that will be filled in with
  -- the length of the list once that's been determined. This works
  -- efficiently (in one pass) because TH.Type is rather lazy. Why isn't this
  -- just a left fold? A left fold will produce a big Q action that, when run,
  -- will produce the type. We want to produce the type incrementally
  -- as we run the Q action. foldM lets us do that, and mfix gives us the thunk
  -- for the tuple size. The irrefutable pattern is required as usual because the
  -- function passed to mfix must never force its argument.
  (res, !_n) <- mfix (\ ~(_res, n) -> foldM go (TupleT n, 0) ts)
  pure res
  where
    -- Applies the next component type and bumps the (strict) element count.
    go :: (Type, Int) -> Q Type -> Q (Type, Int)
    go (acc, !k) ty = do
      ty' <- ty
      pure (acc `AppT` ty', k + 1)