{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE TemplateHaskell #-}

-- | Generate 'HNodes' instances via @TemplateHaskell@
module Hyper.TH.Nodes
    ( makeHNodes
    ) where

import qualified Control.Lens as Lens
import GHC.Generics (V1)
import Hyper.Class.Nodes (HNodes (..), HWitness (..))
import Hyper.TH.Internal.Utils
import Language.Haskell.TH
import qualified Language.Haskell.TH.Datatype as D

import Hyper.Internal.Prelude

-- | Generate a 'HNodes' instance for the data type with the given name.
--
-- The type is first analysed with 'makeTypeInfo'; the resulting description
-- is handed to 'makeHNodesForType', which also emits any auxiliary witness
-- data type the instance refers to.
makeHNodes :: Name -> DecsQ
makeHNodes typeName = makeHNodesForType =<< makeTypeInfo typeName

-- | Build the 'HNodes' instance for an already-analysed type.
--
-- The generated instance defines:
--
-- * @HNodesConstraint@ over a fresh @constraint@ variable. It is built from
--   three parts: that constraint applied to every child type, the
--   @HNodesConstraint@ of every embedded type, and every other collected
--   constraint as-is. The result is simplified and folded into a tuple.
-- * @HWitnessType@. This is 'V1' when the type has no node constructors.
--   Otherwise it is a newly declared @W_\<TypeName\>@ data type applied to
--   the type's parameters.
-- * An @INLINE@ pragma and a definition for 'hLiftConstraint'.
--
-- When a @W_\<TypeName\>@ type is needed, its @data@ declaration is returned
-- after the instance declaration.
makeHNodesForType :: TypeInfo -> DecsQ
makeHNodesForType info =
    sequenceA (hnodesInstance : witDecs)
    where
        inst = tiInstance info

        hnodesInstance :: DecQ
        hnodesInstance =
            instanceD
                (simplifyContext (makeContext info))
                [t|HNodes $(pure inst)|]
                [ D.tySynInstDCompat
                    ''HNodesConstraint
                    (Just [pure (plainTV constraintVar)])
                    [pure inst, c]
                    (fmap toTuple (nodesConstraint >>= simplifyContext))
                , D.tySynInstDCompat ''HWitnessType Nothing [pure inst] witType
                , pure (PragmaD (InlineP 'hLiftConstraint Inline FunLike AllPhases))
                , funD 'hLiftConstraint (makeHLiftConstraints wit)
                ]

        (nodeOfCons, wit) = makeNodeOf info

        -- With no node constructors there is nothing to witness. The
        -- uninhabited 'V1' is used as the witness type, and no extra data
        -- type is declared.
        (witType, witDecs) =
            case nodeOfCons of
                [] -> ([t|V1|], [])
                _ -> (appliedWitType, [witDataDecl])

        witTypeName :: Name
        witTypeName = mkName ("W_" <> niceName (tiName info))

        -- W_<TypeName> applied to each of the type's own parameters, in order.
        appliedWitType :: TypeQ
        appliedWitType =
            foldl appT (conT witTypeName) [varT (D.tvName p) | p <- tiParams info]

        -- data W_<TypeName> <params> node = <one constructor per node>
        -- Each constructor builder receives the applied witness type.
        witDataDecl :: DecQ
        witDataDecl =
            dataD
                (pure [])
                witTypeName
                (tiParams info <> [plainTV (mkName "node")])
                Nothing
                [appliedWitType >>= mkCon | mkCon <- nodeOfCons]
                []

        constraintVar :: Name
        constraintVar = mkName "constraint"

        c :: TypeQ
        c = varT constraintVar

        contents = childrenTypes info

        nodesConstraint :: CxtQ
        nodesConstraint =
            sequenceA
                ( [c `appT` pure child | child <- tcChildren contents ^.. Lens.folded]
                    <> [ [t|HNodesConstraint $(pure embed) $c|]
                       | embed <- tcEmbeds contents ^.. Lens.folded
                       ]
                    <> [pure other | other <- tcOthers contents ^.. Lens.folded]
                )

-- | Instance context for the generated 'HNodes' instance.
--
-- Every constructor field is inspected; only the 'Right' (pattern) side
-- contributes. Container patterns are looked through. A generically embedded
-- type @t@ requires @HNodes t@. A flat embed contributes its own context,
-- computed recursively. All other patterns require nothing.
makeContext :: TypeInfo -> [Pred]
makeContext info =
    [ ctx
    | (_, _, fields) <- tiConstructors info
    , Right pat <- fields
    , ctx <- ctxForPat pat
    ]
    where
        ctxForPat pat =
            case pat of
                InContainer _ inner -> ctxForPat inner
                GenEmbed t -> [ConT ''HNodes `AppT` t]
                FlatEmbed nested -> makeContext nested
                _ -> []

-- | Clauses for the generated 'hLiftConstraint' method.
--
-- * Direct node witnesses: the clause matches @HWitness Ctr@ and its body
--   ignores the first argument and returns the second.
-- * Embedded-type witnesses: the clause matches @HWitness (Ctr witness)@
--   and delegates to 'hLiftConstraint' on the inner witness.
-- * No witnesses at all: a single argument-less clause whose body is an
--   empty lambda-case, so the method is still defined.
makeHLiftConstraints :: NodeWitnesses -> [Q Clause]
makeHLiftConstraints wit =
    case map liftNode (nodeWitCtrs wit) <> map liftEmbed (embedWitCtrs wit) of
        [] -> [clause [] (normalB [|\case {}|]) []]
        clauses -> clauses
    where
        liftNode :: Name -> Q Clause
        liftNode ctr =
            clause [conP 'HWitness [conP ctr []]] (normalB [|\_ r -> r|]) []

        liftEmbed :: Name -> Q Clause
        liftEmbed ctr =
            clause
                [conP 'HWitness [conP ctr [varP innerWitness]]]
                (normalB [|hLiftConstraint $(varE innerWitness)|])
                []

        innerWitness :: Name
        innerWitness = mkName "witness"