{-# LANGUAGE CPP #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE Trustworthy #-}

-- |
-- Module      :   Grisette.Core.TH
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Core.TH
  ( -- * Template Haskell procedures for building constructor wrappers
    makeUnionWrapper,
    makeUnionWrapper',
  )
where

import Control.Monad (join, replicateM, when, zipWithM)
import Grisette.Core.THCompat (augmentFinalType)
import Language.Haskell.TH
  ( Body (NormalB),
    Clause (Clause),
    Con (ForallC, GadtC, InfixC, NormalC, RecC, RecGadtC),
    Dec (DataD, FunD, NewtypeD, SigD),
    Exp (AppE, ConE, LamE, VarE),
    Info (DataConI, TyConI),
    Name,
    Pat (VarP),
    Q,
    Type (ForallT),
    mkName,
    newName,
    pprint,
    reify,
  )
import Language.Haskell.TH.Syntax (Name (Name), OccName (OccName))

-- | Generate constructor wrappers that wrap the result in a union-like monad,
-- using the provided names.
--
-- > $(makeUnionWrapper' ["mrgTuple2"] ''(,))
--
-- generates
--
-- > mrgTuple2 :: (SymBoolOp bool, Monad u, Mergeable bool t1, Mergeable bool t2, MonadUnion bool u) => t1 -> t2 -> u (t1, t2)
-- > mrgTuple2 = \v1 v2 -> mrgSingle (v1, v2)
--
-- The names are matched positionally against the constructors of the type,
-- in the order in which 'reify' lists them. The splice fails when the number
-- of names differs from the number of constructors, or when a constructor
-- has a form for which no wrapper can be generated.
makeUnionWrapper' ::
  -- | Names for generated wrappers
  [String] ->
  -- | The type to generate the wrappers for
  Name ->
  Q [Dec]
makeUnionWrapper' names typName = do
  constructors <- getConstructors typName
  -- 'zipWithM' stops at the end of the shorter list, so a length mismatch
  -- has to be rejected explicitly instead of being silently truncated.
  when (length names /= length constructors) $
    fail "Number of names does not match the number of constructors"
  -- Each constructor yields a signature plus a definition; flatten them.
  join <$> zipWithM mkSingleWrapper names constructors

-- | Extract the bare occurrence string of a 'Name'.
--
-- The name flavour (the second field of the 'Name' constructor) is
-- discarded, so only the unqualified identifier text is returned.
-- NOTE(review): this appears to compute the same result as 'nameBase' from
-- "Language.Haskell.TH" -- confirm before replacing one with the other.
occName :: Name -> String
occName (Name (OccName name) _) = name

-- | Compute the unqualified name of a constructor, for use as the suffix of
-- an automatically named wrapper.
--
-- * Prefix and record constructors (plain or GADT-style with exactly one
--   name) yield their occurrence name.
-- * A 'ForallC' wrapper is looked through to the constructor it quantifies.
-- * Infix constructors are rejected, because an operator name cannot simply
--   be appended to a prefix; the caller is directed to 'makeUnionWrapper''.
-- * Anything else (e.g. a GADT declaration that introduces several
--   constructor names at once) fails with the pretty-printed constructor.
getConstructorName :: Con -> Q String
getConstructorName (NormalC name _) = return $ occName name
getConstructorName (RecC name _) = return $ occName name
getConstructorName InfixC {} =
  fail "You should use makeUnionWrapper' to manually provide the name for infix constructors"
getConstructorName (ForallC _ _ c) = getConstructorName c
getConstructorName (GadtC [name] _ _) = return $ occName name
getConstructorName (RecGadtC [name] _ _) = return $ occName name
getConstructorName c = fail $ "Unsupported constructor at this time: " ++ pprint c

-- | Look up the constructors of the type with the given name.
--
-- The name is reified; a @data@ declaration yields all of its constructors
-- and a @newtype@ declaration yields its single constructor as a one-element
-- list. Any other kind of reified entity makes the splice fail with the
-- pretty-printed 'Info'.
getConstructors :: Name -> Q [Con]
getConstructors typName = do
  info <- reify typName
  case info of
    TyConI (DataD _ _ _ _ constructors _) -> return constructors
    TyConI (NewtypeD _ _ _ _ constructor _) -> return [constructor]
    _ -> fail $ "Unsupported declaration: " ++ pprint info

-- | Generate constructor wrappers that wrap the result in a union-like monad.
--
-- > $(makeUnionWrapper "mrg" ''Maybe)
--
-- generates
--
-- > mrgNothing :: (SymBoolOp bool, Monad u, Mergeable bool t, MonadUnion bool u) => u (Maybe t)
-- > mrgNothing = mrgSingle Nothing
-- > mrgJust :: (SymBoolOp bool, Monad u, Mergeable bool t, MonadUnion bool u) => t -> u (Maybe t)
-- > mrgJust = \x -> mrgSingle (Just x)
--
-- Each wrapper is named by prepending the prefix to the constructor's own
-- name. Types with infix constructors are rejected here; use
-- 'makeUnionWrapper'' and supply the wrapper names explicitly instead.
makeUnionWrapper ::
  -- | Prefix for generated wrappers
  String ->
  -- | The type to generate the wrappers for
  Name ->
  Q [Dec]
makeUnionWrapper prefix typName = do
  constructors <- getConstructors typName
  constructorNames <- mapM getConstructorName constructors
  -- Delegate to the explicit-name variant once the names are derived.
  makeUnionWrapper' (map (prefix ++) constructorNames) typName

-- | Build the body of a wrapper: a lambda that takes @n@ arguments, applies
-- @f@ to all of them in order, and passes the result to @mrgSingle@.
--
-- For @n = 2@ and @f = C@ the generated expression has the shape
--
-- > \x1 x2 -> mrgSingle (C x1 x2)
--
-- where the binder names are fresh (created with 'newName').
--
-- NOTE(review): @mrgSingle@ is only quoted here and is not brought into
-- scope by this module's visible import list, so it is presumably resolved
-- where the generated declarations are spliced -- confirm that callers are
-- expected to have it in scope.
augmentNormalCExpr :: Int -> Exp -> Q Exp
augmentNormalCExpr n f = do
  xs <- replicateM n (newName "x")
  mrgSingleFun <- [|mrgSingle|]
  -- Left fold so that the arguments are applied in declaration order:
  -- ((f x1) x2) ... xn.
  let applied = foldl AppE f (map VarE xs)
  return $ LamE (map VarP xs) (AppE mrgSingleFun applied)

-- | Turn a constructor's type into the type of its wrapper.
--
-- The type underneath any outermost @forall@ is handed to
-- 'augmentFinalType', which returns extra type-variable binders, extra
-- constraints, and the rewritten type. The result is always a 'ForallT'
-- whose binders and context are the extra ones followed by those of the
-- original outermost @forall@ (if there was one).
--
-- NOTE(review): what exactly 'augmentFinalType' adds is not visible in this
-- module; presumably it is the union-monad variable and the constraints
-- shown in the examples for 'makeUnionWrapper' -- confirm in
-- "Grisette.Core.THCompat".
augmentNormalCType :: Type -> Q Type
augmentNormalCType ty = do
  ((bndrs, preds), augmentedTyp) <- augmentFinalType body
  return $ ForallT (bndrs ++ outerBndrs) (preds ++ outerCtx) augmentedTyp
  where
    -- Peel off the outermost quantifier; a type without one contributes no
    -- binders and no context of its own.
    (outerBndrs, outerCtx, body) = case ty of
      ForallT tybinders ctx inner -> (tybinders, ctx, inner)
      _ -> ([], [], ty)

-- | Generate the signature and definition of one wrapper.
--
-- The first argument is the name of the wrapper to generate, the second is
-- the constructor to wrap. Prefix, record and infix constructors are
-- supported; the number of lambda-bound arguments equals the number of
-- constructor fields (always two for an infix constructor). Any other
-- constructor form makes the splice fail with the pretty-printed
-- constructor.
mkSingleWrapper :: String -> Con -> Q [Dec]
mkSingleWrapper name (NormalC oriName fields) =
  mkWrapperDecs name oriName (length fields)
mkSingleWrapper name (RecC oriName fields) =
  mkWrapperDecs name oriName (length fields)
-- 'getConstructorName' directs users to 'makeUnionWrapper'' for infix
-- constructors, so they must be accepted here. 'ConE' on an operator name
-- denotes the constructor in prefix position, e.g. @(:+:) x1 x2@.
mkSingleWrapper name (InfixC _ oriName _) =
  mkWrapperDecs name oriName 2
mkSingleWrapper _ v = fail $ "Unsupported constructor: " ++ pprint v

-- | Build the @SigD@ / @FunD@ pair for a wrapper called @name@ around the
-- constructor @oriName@, which takes @arity@ arguments.
mkWrapperDecs :: String -> Name -> Int -> Q [Dec]
mkWrapperDecs name oriName arity = do
  info <- reify oriName
  -- The constructor's full type is taken from the compiler rather than
  -- rebuilt from the field list.
  constructorTyp <- case info of
    DataConI _ typ _ -> return typ
    _ -> fail $ "Expected a data constructor, but got: " ++ pprint info
  augmentedTyp <- augmentNormalCType constructorTyp
  expr <- augmentNormalCExpr arity (ConE oriName)
  let retName = mkName name
  return
    [ SigD retName augmentedTyp,
      FunD retName [Clause [] (NormalB expr) []]
    ]