{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE TemplateHaskell #-}

-- Helpers for TemplateHaskell instance generators

module Hyper.TH.Internal.Utils
    ( -- Internals for use in TH for sub-classes
      TypeInfo (..)
    , TypeContents (..)
    , CtrTypePattern (..)
    , NodeWitnesses (..)
    , makeTypeInfo
    , makeNodeOf
    , parts
    , toTuple
    , matchType
    , niceName
    , mkNiceTypeName
    , applicativeStyle
    , unapply
    , getVar
    , makeConstructorVars
    , consPat
    , simplifyContext
    , childrenTypes
    ) where

import qualified Control.Lens as Lens
import Control.Monad.Trans.Class (MonadTrans (..))
import Control.Monad.Trans.State (State, evalState, execStateT, gets, modify)
import qualified Data.Char as Char
import Data.List (intercalate, nub)
import qualified Data.Map as Map
import Generic.Data (Generically (..))
import Hyper.Class.Nodes (HWitness (..))
import Hyper.Type (AHyperType (..), GetHyperType, type (:#))
import Language.Haskell.TH
import qualified Language.Haskell.TH.Datatype as D
import Language.Haskell.TH.Datatype.TyVarBndr

import Hyper.Internal.Prelude

-- | What the instance generators need to know about a datatype whose last
-- type parameter plays the hyper-parameter role (built by 'makeTypeInfo',
-- or by 'matchType' for flat-embedded datatypes).
data TypeInfo = TypeInfo
    { tiName :: Name
    -- ^ Name of the datatype's type constructor
    , tiInstance :: Type
    -- ^ The datatype applied to arguments for every parameter except the last
    , tiParams :: [TyVarBndrUnit]
    -- ^ The datatype's type parameters, excluding the last one
    , tiHyperParam :: Name
    -- ^ The type variable used as the hyper-parameter
    , tiConstructors :: [(Name, D.ConstructorVariant, [Either Type CtrTypePattern])]
    -- ^ For each constructor: its name, its variant, and each of its fields as
    -- classified by 'matchType' ('Left' for fields matching no pattern)
    }
    deriving stock (Show)

-- | Types collected from a datatype's constructor fields by 'childrenTypes'.
--
-- The 'Semigroup' and 'Monoid' instances come from 'Generically', so
-- combining two values combines the sets field by field.
data TypeContents = TypeContents
    { tcChildren :: Set Type
    -- ^ Types @x@ found in 'Node' patterns
    , tcEmbeds :: Set Type
    -- ^ Types found in 'GenEmbed' patterns
    , tcOthers :: Set Type
    -- ^ Not filled in by 'childrenTypes'
    }
    deriving stock (Show, Generic)
    deriving (Semigroup, Monoid) via Generically TypeContents

-- | Classification of a constructor field that mentions the hyper-parameter,
-- as produced by 'matchType'.
data CtrTypePattern
    = -- | A child node: @h :# x@ or @GetHyperType h ('AHyperType x)@ yields @x@
      Node Type
    | -- | A datatype applied to the hyper-parameter whose own fields are
      -- analysed and inlined
      FlatEmbed TypeInfo
    | -- | A type applied to the hyper-parameter, kept as an opaque embedding
      GenEmbed Type
    | -- | @f a@ where the argument @a@ matched the inner pattern
      InContainer Type CtrTypePattern
    deriving stock (Show)

-- | Reify a datatype and classify every one of its constructor fields.
--
-- The datatype must have at least one type parameter; its last parameter is
-- taken as the hyper-parameter (see 'parts'). Each field is classified with
-- 'matchType'.
makeTypeInfo :: Name -> Q TypeInfo
makeTypeInfo name =
    do
        info <- D.reifyDatatype name
        (dst, var) <- parts info
        let analyseCon con =
                traverse (matchType name var) (D.constructorFields con)
                    <&> \fields -> (D.constructorName con, D.constructorVariant con, fields)
        cons <- traverse analyseCon (D.datatypeCons info)
        -- 'parts' has already rejected datatypes without parameters, so
        -- dropping the last (hyper) parameter with 'init' cannot fail here.
        pure
            TypeInfo
                { tiName = name
                , tiInstance = dst
                , tiParams = init (D.datatypeVars info)
                , tiHyperParam = var
                , tiConstructors = cons
                }

-- | Split a datatype's parameters into its instance head and hyper-parameter.
--
-- Returns the datatype applied to all of its type variables except the last,
-- paired with the name of that last variable.
--
-- Fails when the datatype has no parameters. Also fails when the last
-- parameter carries an explicit kind that is not 'AHyperType'. A last
-- parameter without a kind annotation is accepted.
parts :: D.DatatypeInfo -> Q (Type, Name)
parts info =
    -- Matching on the reversed list gives the last parameter without using
    -- the partial 'last' / 'init'.
    case reverse (D.datatypeVars info) of
        [] -> fail "expected type constructor which requires arguments"
        hyperVar : revOthers ->
            elimTV (\var -> pure (res, var)) checkKind hyperVar
            where
                res :: Type
                res =
                    foldl AppT (ConT (D.datatypeName info)) (VarT . D.tvName <$> reverse revOthers)
                checkKind :: Name -> Kind -> Q (Type, Name)
                checkKind var (ConT k) | k == ''AHyperType = pure (res, var)
                checkKind _ _ = fail "expected last argument to be a AHyperType variable"

-- | Collect the 'Node' and 'GenEmbed' types reachable from a type's
-- constructors, descending into flat-embedded datatypes.
-- Each instance type is expanded at most once.
childrenTypes :: TypeInfo -> TypeContents
childrenTypes = flip evalState mempty . childrenTypesH

-- | Worker for 'childrenTypes'.
--
-- The state is the set of 'tiInstance' types already expanded. A type found
-- in that set contributes 'mempty', so each type is expanded only once.
childrenTypesH ::
    TypeInfo -> State (Set Type) TypeContents
childrenTypesH info =
    do
        visited <- gets (^. Lens.contains instanceType)
        if visited
            then pure mempty
            else do
                modify (Lens.contains instanceType .~ True)
                contents <-
                    traverse addPat (tiConstructors info ^.. traverse . Lens._3 . traverse . Lens._Right)
                pure (mconcat contents)
    where
        instanceType = tiInstance info
        -- Contribution of a single matched field pattern.
        addPat :: CtrTypePattern -> State (Set Type) TypeContents
        addPat (FlatEmbed inner) = childrenTypesH inner
        addPat (Node x) = pure mempty{tcChildren = mempty & Lens.contains x .~ True}
        addPat (GenEmbed x) = pure mempty{tcEmbeds = mempty & Lens.contains x .~ True}
        addPat (InContainer _ x) = addPat x

-- | Split a type application into its head and its arguments, in order.
--
-- Kind signatures ('SigT') are stripped wherever they wrap the head of the
-- application. Arguments are returned unchanged.
unapply :: Type -> (Type, [Type])
unapply = peel []
    where
        peel :: [Type] -> Type -> (Type, [Type])
        peel acc (SigT t _) = peel acc t
        peel acc (AppT fun arg) = peel (arg : acc) fun
        peel acc t = (t, acc)

-- | Classify a constructor field type relative to the hyper-parameter.
--
-- In @matchType top var t@, @top@ is the name of the datatype being analysed
-- and @var@ is its hyper-parameter variable. Cases are tried in order:
--
-- * @GetHyperType var ('AHyperType x)@, or @var :# x@ (infix or prefix):
--   'Node' @x@.
-- * @x var@, where @x@ is not @GetHyperType@: if @x@ is headed by a datatype
--   @c /= top@, @c@ is reified and its fields are classified recursively.
--   Its parameters are substituted by @x@'s arguments, and its last parameter
--   by @var@. The result is 'FlatEmbed' when @var@ does not occur free in any
--   unmatched field, and 'GenEmbed' @x@ otherwise. Any other @x@ also gives
--   'GenEmbed' @x@.
-- * @f a@: if @a@ matches, 'InContainer' @f@ around its pattern; otherwise the
--   whole @f a@ is returned as 'Left'.
-- * Anything else is returned unchanged as 'Left'.
matchType :: Name -> Name -> Type -> Q (Either Type CtrTypePattern)
matchType _ var (ConT get `AppT` VarT h `AppT` (PromotedT aHyper `AppT` x))
    | get == ''GetHyperType && aHyper == 'AHyperType && h == var =
        pure (Right (Node x))
matchType _ var (InfixT (VarT h) hash x)
    | hash == ''(:#) && h == var =
        pure (Right (Node x))
matchType _ var (ConT hash `AppT` VarT h `AppT` x)
    | hash == ''(:#) && h == var =
        pure (Right (Node x))
matchType top var (x `AppT` VarT h)
    | h == var && x /= ConT ''GetHyperType =
        Right <$> case unapply x of
            (ConT c, args) | c /= top -> embedding c args
            _ -> pure (GenEmbed x)
    where
        -- Analyse datatype @c@, which the field applies to @args@ and then @var@.
        embedding c args =
            do
                inner <- D.reifyDatatype c
                let innerVars = D.tvName <$> D.datatypeVars inner
                    -- Instantiate the inner datatype's parameters with the
                    -- arguments it is applied to here, and its last one with @var@.
                    subst = Map.fromList (zip innerVars (args <> [VarT var]))
                    analyseCon con =
                        traverse (matchType top var . D.applySubstitution subst) (D.constructorFields con)
                            <&> \fields -> (D.constructorName con, D.constructorVariant con, fields)
                cons <- traverse analyseCon (D.datatypeCons inner)
                let unmatched = cons ^.. traverse . Lens._3 . traverse . Lens._Left
                if var `notElem` (D.tvName <$> D.freeVariablesWellScoped unmatched)
                    then
                        pure
                            ( FlatEmbed
                                TypeInfo
                                    { tiName = c
                                    , tiInstance = x
                                    , tiParams = init (D.datatypeVars inner)
                                    , tiHyperParam = var
                                    , tiConstructors = cons
                                    }
                            )
                    else pure (GenEmbed x)
matchType top var x@(AppT f a) =
    -- TODO: check if applied over a functor-kinded type.
    matchType top var a <&> either (const (Left x)) (Right . InContainer f)
matchType _ _ t = pure (Left t)

-- | The name of a type variable, looking through kind signatures.
-- Returns 'Nothing' for any other type.
getVar :: Type -> Maybe Name
getVar t =
    case t of
        VarT name -> Just name
        SigT inner _ -> getVar inner
        _ -> Nothing

-- | The tuple type whose components are the given types, in order:
-- @'TupleT' n@ applied to each component, where @n@ is the number of
-- components.
toTuple :: Foldable t => t Type -> Type
toTuple components = foldl AppT (TupleT (length components)) components

-- | Build the expression @pure f \<*\> x1 \<*\> ... \<*\> xn@ from @f@ and the
-- argument expressions @x1 .. xn@ (left-associated).
applicativeStyle :: Q Exp -> [Q Exp] -> Q Exp
applicativeStyle f args = foldl apply [|pure $f|] args
    where
        apply :: Q Exp -> Q Exp -> Q Exp
        apply acc x = [|$acc <*> $x|]

-- | Pair each field with the variable name @_\<prefix\>\<index\>@, where the
-- index counts from 0. Names are built with 'mkName', so they are not fresh.
makeConstructorVars :: String -> [a] -> [(a, Name)]
makeConstructorVars prefix fields =
    zip fields [mkName ('_' : prefix <> show i) | i <- [0 :: Int ..]]

-- | A constructor pattern for constructor @c@ that binds each field to the
-- 'Name' paired with it; the first component of each pair is ignored.
consPat :: Name -> [(a, Name)] -> Q Pat
consPat c vars = conP c [varP v | (_, v) <- vars]

-- | Simplify an instance context by expanding predicates that have a single
-- matching instance.
--
-- For a predicate @C t1 .. tn@ with a constructor head, 'reifyInstances' is
-- queried. If exactly one instance is returned, the predicate is replaced by
-- that instance's context, after unifying the instance head with the
-- predicate, and the replacement is simplified recursively. Otherwise the
-- predicate is kept.
--
-- Predicates on a single type variable, and predicates with a
-- non-constructor head, are kept without a lookup. Each
-- @(class, arguments)@ pair is looked up at most once.
--
-- The result is read out of a 'Set', so it has no duplicates and is in
-- 'Set' order rather than input order.
simplifyContext :: [Pred] -> CxtQ
simplifyContext preds =
    execStateT (goPreds preds) (mempty :: Set (Name, [Type]), mempty :: Set Pred)
        <&> (^.. Lens._2 . Lens.folded)
    where
        -- State: (class/argument pairs already looked up, predicates kept).
        goPreds = traverse_ (go . unapply)
        go (c, [VarT v]) =
            -- Work-around reifyInstances returning instances for type variables
            -- by not checking.
            yep c [VarT v]
        go (ConT c, xs) =
            do
                alreadyChecked <- Lens.use (Lens._1 . Lens.contains key)
                if alreadyChecked
                    then pure ()
                    else do
                        Lens._1 . Lens.contains key .= True
                        instances <- lift (reifyInstances c xs)
                        case instances of
                            [InstanceD _ context other _] ->
                                do
                                    subst <- lift (D.unifyTypes [other, foldl AppT (ConT c) xs])
                                    goPreds (D.applySubstitution subst context)
                            _ -> yep (ConT c) xs
            where
                key = (c, xs)
        go (c, xs) = yep c xs
        -- Keep the predicate @c xs@ in the result.
        yep c xs = Lens._2 . Lens.contains (foldl AppT c xs) .= True

-- | Witness helpers for a type's generated witness constructors
-- (built by 'makeNodeOf').
data NodeWitnesses = NodeWitnesses
    { nodeWit :: Type -> Q Exp
    -- ^ Witness expression for a node type
    , embedWit :: Type -> Q Exp
    -- ^ Witness expression for an embedded type
    , nodeWitCtrs :: [Name]
    -- ^ Names of the node witness constructors
    , embedWitCtrs :: [Name]
    -- ^ Names of the embed witness constructors
    }

-- | The text of a name's 'show' output after its last @'.'@, i.e. the name
-- without its module qualifier.
--
-- NOTE(review): for operator names that themselves contain @'.'@, only the
-- text after the last dot is kept (e.g. the result is empty for @(.)@).
-- This differs from 'nameBase'.
niceName :: Name -> String
niceName name = reverse (takeWhile (/= '.') (reverse (show name)))

-- | Generate witness GADT constructors and witness helpers for a type.
--
-- Node types come from 'Node' patterns, including those inside containers
-- and flat-embedded datatypes. Each node type @x@ gets a field-less
-- constructor @W_\<Type\>_\<x\>@ whose result type is @c x@.
--
-- Each 'GenEmbed' type @x@ gets a constructor @E_\<Type\>_\<x\>@ taking an
-- @HWitness x node@ field, with result type @c node@.
--
-- In both cases @c@ is the 'Type' passed to the returned functions. Each type
-- gets one constructor, in first-occurrence order.
--
-- The helpers in 'NodeWitnesses' build @HWitness W_...@ and
-- @HWitness . E_...@ expressions. They call 'error' for types that have no
-- witness constructor.
makeNodeOf :: TypeInfo -> ([Type -> Q Con], NodeWitnesses)
makeNodeOf info =
    ( map nodeGadtType nodes <> map embedGadtType embeds
    , NodeWitnesses
        { nodeWit = \t -> [|HWitness $(conE (getWit nodeMap t))|]
        , embedWit = \t -> [|HWitness . $(conE (getWit embedMap t))|]
        , nodeWitCtrs = map snd nodes
        , embedWitCtrs = map snd embeds
        }
    )
    where
        niceTypeName :: String
        niceTypeName = niceName (tiName info)
        nodeBase :: String
        nodeBase = "W_" <> niceTypeName <> "_"
        embedBase :: String
        embedBase = "E_" <> niceTypeName <> "_"
        -- Matched ('Right') field patterns of all constructors of a type.
        matchedPats :: TypeInfo -> [CtrTypePattern]
        matchedPats x = tiConstructors x ^.. traverse . Lens._3 . traverse . Lens._Right
        nodes :: [(Type, Name)]
        nodes =
            [ (t, mkName (nodeBase <> mkNiceTypeName t))
            | t <- nub (matchedPats info >>= nodesForPat)
            ]
        nodesForPat :: CtrTypePattern -> [Type]
        nodesForPat (Node t) = [t]
        nodesForPat (InContainer _ pat) = nodesForPat pat
        nodesForPat (FlatEmbed x) = matchedPats x >>= nodesForPat
        nodesForPat _ = []
        embeds :: [(Type, Name)]
        embeds =
            [ (t, mkName (embedBase <> mkNiceTypeName t))
            | t <- nub (matchedPats info >>= embedsForPat)
            ]
        embedsForPat :: CtrTypePattern -> [Type]
        embedsForPat (GenEmbed t) = [t]
        embedsForPat (InContainer _ pat) = embedsForPat pat
        embedsForPat (FlatEmbed x) = matchedPats x >>= embedsForPat
        embedsForPat _ = []
        nodeMap :: Map Type Name
        nodeMap = Map.fromList nodes
        embedMap :: Map Type Name
        embedMap = Map.fromList embeds
        nodeGadtType :: (Type, Name) -> Type -> Q Con
        nodeGadtType (t, n) c = gadtC [n] [] (pure (c `AppT` t))
        embedGadtType :: (Type, Name) -> Type -> Q Con
        embedGadtType (t, n) c =
            gadtC
                [n]
                [ bangType
                    (bang noSourceUnpackedness noSourceStrictness)
                    [t|HWitness $(pure t) $nodeVar|]
                ]
                [t|$(pure c) $nodeVar|]
        nodeVar :: Q Type
        nodeVar = varT (mkName "node")
        getWit :: Map Type Name -> Type -> Name
        getWit m h =
            fromMaybe
                (error ("Cant find witness for " <> show h <> " in " <> show m))
                (Map.lookup h m)

-- | A readable identifier fragment for a type, used in witness constructor
-- names. The parts are joined with @"_"@:
--
-- * Constructor names appear without module qualifier; operators are skipped.
-- * Applications contribute their head's parts, then their argument's parts.
-- * Type variables contribute their 'show' text up to the first @'_'@.
-- * Kind signatures are looked through.
--
-- Any other type form calls 'error'.
mkNiceTypeName :: Type -> String
mkNiceTypeName = intercalate "_" . nameParts
    where
        nameParts :: Type -> [String]
        nameParts ty =
            case ty of
                ConT name ->
                    case niceName name of
                        n@(c : _) | Char.isAlpha c -> [n]
                        _ -> [] -- Skip operators
                AppT fun arg -> nameParts fun <> nameParts arg
                -- NOTE(review): this also truncates variable names that
                -- contain '_' in their base, e.g. "my_var" becomes "my".
                VarT name -> [takeWhile (/= '_') (show name)]
                SigT inner _ -> nameParts inner
                _ -> error ("TODO: Witness name generator is partial! Need to support " <> show ty)