{-# LANGUAGE CPP #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE Trustworthy #-}

-- |
-- Module      :   Grisette.Core.TH
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Core.TH
  ( -- * Template Haskell procedures for building constructor wrappers
    makeUnionWrapper,
    makeUnionWrapper',
  )
where

import Control.Monad (join, replicateM, when, zipWithM)
import Grisette.Core.THCompat (augmentFinalType)
import Language.Haskell.TH
  ( Body (NormalB),
    Clause (Clause),
    Con (ForallC, GadtC, InfixC, NormalC, RecC, RecGadtC),
    Dec (DataD, FunD, NewtypeD, SigD),
    Exp (AppE, ConE, LamE, VarE),
    Info (DataConI, TyConI),
    Name,
    Pat (VarP),
    Q,
    Type (ForallT),
    mkName,
    newName,
    pprint,
    reify,
  )
import Language.Haskell.TH.Syntax (Name (Name), OccName (OccName))

-- | Generate constructor wrappers that wrap the result in a union-like monad,
-- using the provided names for the generated wrappers.
--
-- The @i@-th name is paired with the @i@-th constructor, in the order the
-- constructors are obtained by reifying the type. The splice fails when the
-- number of names differs from the number of constructors.
--
-- > $(makeUnionWrapper' ["mrgTuple2"] ''(,))
--
-- generates
--
-- > mrgTuple2 :: (SymBoolOp bool, Monad u, Mergeable bool t1, Mergeable bool t2, MonadUnion bool u) => t1 -> t2 -> u (t1, t2)
-- > mrgTuple2 = \v1 v2 -> mrgSingle (v1, v2)
makeUnionWrapper' ::
  -- | Names for generated wrappers
  [String] ->
  -- | The type to generate the wrappers for
  Name ->
  Q [Dec]
makeUnionWrapper' names typName = do
  constructors <- getConstructors typName
  let nameCount = length names
      constructorCount = length constructors
  when (nameCount /= constructorCount) $
    fail "Number of names does not match the number of constructors"
  -- Every constructor contributes its own list of declarations; flatten them.
  concat <$> zipWithM mkSingleWrapper names constructors

-- | Extract the occurrence-name string stored in a 'Name', ignoring its
-- 'NameFlavour' (the part that records qualification / binding information).
occName :: Name -> String
occName n = case n of
  Name (OccName base) _ -> base

-- | Return the bare name of a constructor, used by 'makeUnionWrapper' to
-- derive wrapper names.
--
-- Prefix, record and single-name GADT constructors are supported.
-- 'ForallC' wrappers are looked through. Infix constructors are rejected with
-- a hint to name them explicitly through 'makeUnionWrapper''. Any other shape
-- makes the splice fail.
getConstructorName :: Con -> Q String
getConstructorName con = case con of
  NormalC nm _ -> pure (occName nm)
  RecC nm _ -> pure (occName nm)
  InfixC {} ->
    fail "You should use makeUnionWrapper' to manually provide the name for infix constructors"
  ForallC _ _ inner -> getConstructorName inner
  GadtC [nm] _ _ -> pure (occName nm)
  RecGadtC [nm] _ _ -> pure (occName nm)
  _ -> fail $ "Unsupported constructor at this time: " ++ pprint con

-- | Reify a type name and return its constructors.
--
-- A @data@ declaration yields all its constructors. A @newtype@ yields its
-- single constructor. Any other reified entity makes the splice fail, and the
-- message includes the pretty-printed 'Info'.
getConstructors :: Name -> Q [Con]
getConstructors typName = reify typName >>= extract
  where
    extract (TyConI (DataD _ _ _ _ cons _)) = pure cons
    extract (TyConI (NewtypeD _ _ _ _ con _)) = pure [con]
    extract info = fail $ "Unsupported declaration: " ++ pprint info

-- | Generate constructor wrappers that wrap the result in a union-like monad.
--
-- Each wrapper is named by prepending the given prefix to the constructor's
-- name. Types with infix constructors are rejected; use 'makeUnionWrapper''
-- to supply the names explicitly.
--
-- > $(makeUnionWrapper "mrg" ''Maybe)
--
-- generates
--
-- > mrgNothing :: (SymBoolOp bool, Monad u, Mergeable bool t, MonadUnion bool u) => u (Maybe t)
-- > mrgNothing = mrgSingle Nothing
-- > mrgJust :: (SymBoolOp bool, Monad u, Mergeable bool t, MonadUnion bool u) => t -> u (Maybe t)
-- > mrgJust = \x -> mrgSingle (Just x)
makeUnionWrapper ::
  -- | Prefix for generated wrappers
  String ->
  -- | The type to generate the wrappers for
  Name ->
  Q [Dec]
makeUnionWrapper prefix typName = do
  baseNames <- getConstructors typName >>= traverse getConstructorName
  makeUnionWrapper' (map (prefix ++) baseNames) typName

-- | Build the body of a wrapper for the constructor expression @f@ of arity
-- @n@:
--
-- > \x1 ... xn -> mrgSingle (f x1 ... xn)
--
-- When @n@ is not positive, no lambda is emitted and the result is simply
-- @mrgSingle f@. This avoids a degenerate zero-pattern 'LamE', which
-- pretty-prints oddly in @-ddump-splices@ output.
augmentNormalCExpr :: Int -> Exp -> Q Exp
augmentNormalCExpr n f = do
  xs <- replicateM n (newName "x")
  mrgSingleFun <- [|mrgSingle|]
  -- Apply the constructor to the fresh variables left to right: f x1 ... xn.
  let applied = AppE mrgSingleFun (foldl AppE f (map VarE xs))
  return $
    if null xs
      then applied
      else LamE (map VarP xs) applied

-- | Rewrite a reified constructor type into the type of its union wrapper.
--
-- Any outer @forall@ is split off first. 'augmentFinalType' rewrites the
-- remaining type and reports extra binders and predicates. These are placed in
-- front of the original binders and context in the resulting 'ForallT'. For a
-- type without an outer @forall@, the original binders and context are empty.
augmentNormalCType :: Type -> Q Type
augmentNormalCType ty = do
  let (outerBndrs, outerCtx, inner) = case ty of
        ForallT tvs ctx body -> (tvs, ctx, body)
        _ -> ([], [], ty)
  ((newBndrs, newPreds), rewritten) <- augmentFinalType inner
  return $ ForallT (newBndrs ++ outerBndrs) (newPreds ++ outerCtx) rewritten

-- | Generate the type signature and definition of one wrapper, named @name@,
-- for the given constructor.
--
-- The wrapper's type is the constructor's reified type rewritten by
-- 'augmentNormalCType'. Its body comes from 'augmentNormalCExpr', applied to
-- the constructor's field count.
--
-- Prefix, record and infix constructors are supported. Infix constructors must
-- be handled here because 'getConstructorName' directs users of infix
-- constructors to 'makeUnionWrapper''. Any other constructor form makes the
-- splice fail.
--
-- NOTE(review): 'ForallC' / 'GadtC' / 'RecGadtC' constructors are accepted by
-- 'getConstructorName' but still rejected here. Supporting them depends on how
-- 'augmentFinalType' treats GADT result types, which should be confirmed first.
mkSingleWrapper :: String -> Con -> Q [Dec]
mkSingleWrapper name con = case con of
  NormalC oriName fields -> build oriName (length fields)
  RecC oriName fields -> build oriName (length fields)
  -- An infix constructor always has exactly two fields.
  InfixC _ oriName _ -> build oriName 2
  _ -> fail $ "Unsupported constructor: " ++ pprint con
  where
    -- Emit the signature and definition for the constructor 'oriName' with
    -- 'arity' fields.
    build oriName arity = do
      info <- reify oriName
      constructorTyp <- case info of
        DataConI _ ty _ -> return ty
        _ -> fail $ "Not a data constructor: " ++ pprint oriName
      augmentedTyp <- augmentNormalCType constructorTyp
      let retName = mkName name
      expr <- augmentNormalCExpr arity (ConE oriName)
      return
        [ SigD retName augmentedTyp,
          FunD retName [Clause [] (NormalB expr) []]
        ]