{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE CPP #-}

-- | Template-Haskell helpers for EADTs
module Haskus.Utils.EADT.TH
   ( eadtPattern
   , eadtInfixPattern
   , eadtPatternT
   , eadtInfixPatternT
   )
where

import Language.Haskell.TH
import Control.Monad
import Haskus.Utils.EADT

-- | Create a (prefix) pattern synonym for an EADT constructor
--
-- The generated pattern works for any EADT containing the constructor's
-- functor (generic @EADT xs@ type). Use 'eadtPatternT' to fix the EADT type.
--
-- E.g.
--
-- > data ConsF a e = ConsF a e deriving (Functor)
-- > $(eadtPattern 'ConsF "Cons")
-- >
-- > ====>
-- >
-- > pattern Cons :: ConsF a :<: xs => a -> EADT xs -> EADT xs
-- > pattern Cons a l = VF (ConsF a l)
--
eadtPattern
   :: Name       -- ^ Actual constructor (e.g., ConsF)
   -> String     -- ^ Name of the pattern (e.g., Cons)
   -> Q [Dec]
eadtPattern consName patStr = eadtPattern' consName patStr Nothing False

-- | Create an infix pattern synonym for an EADT constructor
--
-- The constructor must take exactly two arguments, otherwise the splice
-- fails. The generated pattern uses a generic @EADT xs@ type; use
-- 'eadtInfixPatternT' to fix the EADT type.
--
-- E.g.
--
-- > data ConsF a e = ConsF a e deriving (Functor)
-- > $(eadtInfixPattern 'ConsF ":->")
-- >
-- > ====>
-- >
-- > pattern (:->) :: ConsF a :<: xs => a -> EADT xs -> EADT xs
-- > pattern a :-> l = VF (ConsF a l)
--
eadtInfixPattern
   :: Name       -- ^ Actual constructor (e.g., ConsF)
   -> String     -- ^ Name of the infix pattern (e.g., :->)
   -> Q [Dec]
eadtInfixPattern consName patStr = eadtPattern' consName patStr Nothing True

-- | Create a pattern synonym for an EADT constructor that is part of a
-- specified EADT.
--
-- This can be useful to help the type inference because instead of using a
-- generic "EADT xs" type, the pattern uses the provided type.
--
-- E.g.
--
-- > data ConsF a e = ConsF a e deriving (Functor)
-- > data NilF    e = NilF      deriving (Functor)
-- >
-- > type List a = EADT '[ConsF a, NilF]
-- >
-- > $(eadtPatternT 'ConsF "ConsList" [t|forall a. List a|])
-- >
-- > ====>
-- >
-- > pattern ConsList ::
-- >  ( List a ~ EADT xs
-- >  , ConsF a :<: xs
-- >  ) => a -> List a -> List a
-- > pattern ConsList a l = VF (ConsF a l)
--
-- Note that you have to quantify free variables explicitly with 'forall'
--
eadtPatternT
   :: Name       -- ^ Actual constructor (e.g., ConsF)
   -> String     -- ^ Name of the pattern (e.g., Cons)
   -> Q Type     -- ^ Type of the EADT (e.g., [t|forall a. List a|])
   -> Q [Dec]
eadtPatternT consName patStr qtype =
   eadtPattern' consName patStr (Just qtype) False

-- | Like 'eadtPatternT' but generating an infix pattern synonym
--
-- The constructor must take exactly two arguments, otherwise the splice
-- fails.
--
-- E.g.
--
-- > $(eadtInfixPatternT 'ConsF ":->" [t|forall a. List a|])
--
-- Note that you have to quantify free variables explicitly with 'forall'
eadtInfixPatternT
   :: Name       -- ^ Actual constructor (e.g., ConsF)
   -> String     -- ^ Name of the infix pattern (e.g., :->)
   -> Q Type     -- ^ Type of the EADT (e.g., [t|forall a. List a|])
   -> Q [Dec]
eadtInfixPatternT consName patStr qtype =
   eadtPattern' consName patStr (Just qtype) True

-- | Create a pattern synonym for an EADT constructor
--
-- Shared implementation of 'eadtPattern', 'eadtInfixPattern',
-- 'eadtPatternT' and 'eadtInfixPatternT'.
--
-- The constructor is reified and must have a type of the form
-- @forall ... e. a1 -> ... -> an -> T ... e@ where the last quantified
-- variable @e@ is the functor variable. Two declarations are returned:
--
--   * an implicitly bidirectional pattern synonym matching
--     @VF (Con c1 ... cn)@, either prefix (@P c1 ... cn@) or infix
--     (e.g. @c1 :-> c2@, which requires exactly two arguments);
--
--   * its signature, quantified over the constructor's type variables, a
--     fresh @xs :: [* -> *]@ and the variables of the provided EADT type (if
--     any), under the constraints @T ... :<: xs@ and @e ~ EADT xs@ (plus
--     @ty ~ EADT xs@ when an EADT type @ty@ is provided). Arguments of type
--     @e@ and the result type are replaced by the EADT type (@EADT xs@, or
--     the provided @ty@).
--
-- The splice fails (via 'fail') if the name isn't a data constructor, if its
-- type has no quantified variable, if its result type can't be split into
-- @T ... e@, or if an infix pattern is requested for a constructor that
-- doesn't take exactly two arguments.
eadtPattern'
   :: Name           -- ^ Actual constructor (e.g., ConsF)
   -> String         -- ^ Name of the pattern (e.g., Cons)
   -> Maybe (Q Type) -- ^ EADT type
   -> Bool           -- ^ Declare infix pattern
   -> Q [Dec]
eadtPattern' consName patStr mEadtTy isInfix = do
   let patName     = mkName patStr
       notAFunctor = show consName
         ++ "'s type doesn't have a free variable, it can't be a functor"

   typ <- reify consName >>= \case
            DataConI _ t _ -> return t
            _              -> fail $ show consName ++ " isn't a data constructor"

   case typ of
      ForallT tvs _ tys -> do
         let
            -- name of a type variable binder (binders gained a flag field
            -- in template-haskell 2.17 / base 4.15)
#if MIN_VERSION_base(4,15,0)
            tyVarName = \case
               PlainTV  nm _   -> nm
               KindedTV nm _ _ -> nm
#else
            tyVarName = \case
               PlainTV  nm   -> nm
               KindedTV nm _ -> nm
#endif

            -- number of arguments of the data constructor
            getConArity :: Type -> Int
            getConArity = \case
               AppT (AppT ArrowT _a) b              -> 1 + getConArity b
#if MIN_VERSION_base(4,15,0)
               AppT (AppT (AppT MulArrowT _m) _a) b -> 1 + getConArity b
#endif
               _                                    -> 0

            -- retrieve constructor type without the functor var
            -- e.g. ConsF a for ConsF a e
            getConTyp :: Type -> Maybe Type
            getConTyp = \case
               AppT (AppT ArrowT _a) b              -> getConTyp b
#if MIN_VERSION_base(4,15,0)
               AppT (AppT (AppT MulArrowT _m) _a) b -> getConTyp b
#endif
               AppT a _                             -> Just a -- drop the functor var
               _                                    -> Nothing

            -- [* -> *]
            tyToTyList = AppT ListT (AppT (AppT ArrowT StarT) StarT)

            -- NOTE(review): 'mkName' resolves this qualified name at the
            -- splice site, so "Haskus.Utils.EADT.VF" must be in scope there
            vf = mkName "Haskus.Utils.EADT.VF"

         -- make pattern
         conArgs <- replicateM (getConArity tys) (newName "c")

         args <- if not isInfix
            then return (PrefixPatSyn conArgs)
            else case conArgs of
                  [x,y] -> return (InfixPatSyn x y)
                  xs    -> fail $ "Infix pattern should have exactly two parameters (found "
                                    ++ show (length xs) ++ ")"

         let pat = PatSynD patName args ImplBidir
#if MIN_VERSION_base(4,16,0)
                      -- handle new field for type-applications in patterns
                      (ConP vf [] [ConP consName [] (fmap VarP conArgs)])
#else
                      (ConP vf [ConP consName (fmap VarP conArgs)])
#endif

         -- retrieve functor var "e" (the last quantified type variable)
         e <- case reverse tvs of
            tv : _ -> return (tyVarName tv)
            []     -> fail notAFunctor

         conTyp <- maybe (fail "Invalid constructor type") return (getConTyp tys)

         -- make pattern type
         xsName <- newName "xs"
         let xs = VarT xsName
#if MIN_VERSION_base(4,15,0)
             xsTy = KindedTV xsName SpecifiedSpec tyToTyList
#else
             xsTy = KindedTV xsName tyToTyList
#endif
         eadtXs <- [t| EADT $(return xs) |]
         prd    <- [t| $(return conTyp) :<: $(return xs) |]
         prd2   <- [t| $(return (VarT e)) ~ $(return eadtXs) |]

         (newTvs, eadtTy, ctx) <- case mEadtTy of
            Nothing  -> return ([xsTy], eadtXs, [prd, prd2])
            Just qty -> do
               ty <- qty
               -- put freevars of the user specified type with the other ones
               let (userTvs, userTy, userCtx) = case ty of
                      ForallT bndrs uCtx t -> (bndrs, t, uCtx)
                      _                    -> ([], ty, [])
               prd3 <- [t| $(return userTy) ~ $(return eadtXs) |]
               return (xsTy : userTvs, userTy, prd : prd2 : prd3 : userCtx)

         let
            -- replace the functor variable with the EADT type in argument
            -- position, and the result type with the EADT type
            go (AppT (AppT ArrowT a) b)
               | VarT v <- a, v == e = AppT (AppT ArrowT eadtTy) (go b)
               | otherwise           = AppT (AppT ArrowT a) (go b)
#if MIN_VERSION_base(4,15,0)
            go (AppT (AppT (AppT MulArrowT m) a) b)
               | VarT v <- a, v == e = AppT (AppT (AppT MulArrowT m) eadtTy) (go b)
               | otherwise           = AppT (AppT (AppT MulArrowT m) a) (go b)
#endif
            go _ = eadtTy

            -- keep every type variable of the constructor (the functor var is
            -- tied to the EADT by the "e ~ EADT xs" constraint) and add the
            -- new ones
            sig = PatSynSigD patName (ForallT (tvs ++ newTvs) ctx (go tys))

         return [sig, pat]

      _ -> fail notAFunctor