-- SPDX-FileCopyrightText: 2021 Oxhead Alpha
-- SPDX-License-Identifier: LicenseRef-MIT-OA

module Morley.Util.TH
  ( deriveGADTNFData
  , lookupTypeNameOrFail
  , isTypeAlias
  , addTypeVariables
  ) where

import Language.Haskell.TH

{-# ANN module ("HLint: ignore Language.Haskell.TH should be imported post-qualified or with an explicit import list" :: Text) #-}

-- | Generates an 'NFData' instance for a GADT.
--
-- For every constructor @C@ with fields @a1 ... an@ the generated @rnf@ has
-- an equation of the shape
--
-- > rnf (C a1 ... an) = rnf a1 `seq` (rnf a2 `seq` (... `seq` ()))
--
-- Accepted constructor forms: plain GADT signatures (@C :: A -> T@), record
-- GADT signatures (@C :: { f :: A } -> T@), several names sharing a single
-- signature (@A, B :: ...@), and any of these under a @forall@/context.
-- A datatype without constructors gets
--
-- > rnf x = x `seq` ()
--
-- Compilation fails if @name@ does not refer to a @data@ declaration, or if
-- any of its constructors is not written in GADT syntax.
--
-- /Note:/ This will not generate additional constraints to the generated
-- instance if those are required.
deriveGADTNFData :: Name -> Q [Dec]
deriveGADTNFData name = do
  (dataName, vars, cons) <- reify name >>= \case
    TyConI (DataD _ dataName vars _ cons _) -> pure (dataName, vars, cons)
    _ -> fail $ "deriveGADTNFData: expected a data type, got " <> show name
  -- Unfolding happens in 'Q' (not in the list monad) so that an unsupported
  -- constructor aborts the splice instead of being silently dropped.
  conFields <- mconcat <$> traverse unfoldConstructor cons
  clauses <- case conFields of
    -- A function binding with zero equations is not valid Haskell.
    [] -> (: []) <$> emptyClause
    _ -> traverse makeClause conFields
  let instanceHead =
        ConT ''NFData `AppT` foldl' AppT (ConT dataName) (map (VarT . tyVarName) vars)
  pure [InstanceD Nothing [] instanceHead [FunD 'rnf clauses]]
  where
    -- Name of a type variable binder, whatever its flag or kind annotation.
    tyVarName :: TyVarBndr flag -> Name
    tyVarName = \case
      PlainTV n _ -> n
      KindedTV n _ _ -> n

    -- Expands one constructor declaration into (constructor name, fields)
    -- pairs; a single GADT signature may declare several constructors.
    unfoldConstructor :: Con -> Q [(Name, [BangType])]
    unfoldConstructor = \case
      GadtC conNames fields _ -> pure $ map (, fields) conNames
      RecGadtC conNames fields _ ->
        pure $ map (, map (\(_, b, t) -> (b, t)) fields) conNames
      ForallC _ _ con -> unfoldConstructor con
      _ -> fail "Non GADT constructors are not supported."

    -- Builds "rnf (C a1 ... an) = rnf a1 `seq` (... `seq` ())".
    makeClause :: (Name, [BangType]) -> Q Clause
    makeClause (conName, fields) = do
      argNames <- traverse (\_ -> newName "a") fields
      let forceThen arg rest =
            infixE (Just $ varE 'rnf `appE` varE arg) (varE 'seq) (Just rest)
      clause
        [conP conName (map varP argNames)]
        (normalB $ foldr forceThen [| () |] argNames)
        []

    -- Builds "rnf x = x `seq` ()" for datatypes without constructors.
    emptyClause :: Q Clause
    emptyClause = do
      x <- newName "x"
      clause [varP x] (normalB $ infixE (Just $ varE x) (varE 'seq) (Just [| () |])) []

-- | Resolve a type name via 'lookupTypeName', aborting the splice with a
-- descriptive error when no type of that name is found.
lookupTypeNameOrFail :: String -> Q Name
lookupTypeNameOrFail typeStr =
  maybe notFound pure =<< lookupTypeName typeStr
  where
    notFound :: Q Name
    notFound = fail $ "Failed type name lookup for: '" <> typeStr <> "'."

-- | Check whether the given name refers to a type synonym, i.e. a
-- declaration of the form @type T ... = ...@.
--
-- Only information that 'reify' reports as a plain synonym ('TySynD') counts;
-- every other kind of declaration yields 'False'. The name must be
-- reifiable, otherwise the splice fails.
isTypeAlias :: Name -> Q Bool
isTypeAlias typeName = isSynonym <$> reify typeName
  where
    isSynonym :: Info -> Bool
    isSynonym = \case
      TyConI (TySynD {}) -> True
      _ -> False

-- | Accepts the name of a @data@ or @newtype@ type constructor and applies
-- it to one type variable per parameter binder that 'reify' reports for the
-- declaration, reusing the binders' own names. For example, @data T a b = ...@
-- yields the type @T a b@.
--
-- For ordinary declarations this produces a fully applied type of kind @*@;
-- NOTE(review): declarations whose result kind is given only by a kind
-- signature may reify with fewer binders — confirm if that case matters.
--
-- Fails when the name refers to anything other than a @data@ or @newtype@
-- declaration.
addTypeVariables :: Name -> TypeQ
addTypeVariables tyCtor = do
  tyVarBindrs <- reify tyCtor >>= \case
    TyConI (DataD _ _ tyVarBindrs _ _ _) -> pure tyVarBindrs
    TyConI (NewtypeD _ _ tyVarBindrs _ _ _) -> pure tyVarBindrs
    _ -> fail $ "Expected a plain datatype, got: " <> show tyCtor
  let vars = map tyVarName tyVarBindrs
  -- Left-nested application: ((T a) b) ...
  pure $ foldl' (\acc var -> acc `AppT` VarT var) (ConT tyCtor) vars
  where
    -- Name of a type variable binder, whatever its flag or kind annotation.
    tyVarName :: TyVarBndr flag -> Name
    tyVarName = \case
      PlainTV vName _ -> vName
      KindedTV vName _ _ -> vName