{-# LANGUAGE CPP #-}
{-# LANGUAGE TemplateHaskell #-}

module Hyper.TH.Morph
    ( makeHMorph
    ) where

import qualified Control.Lens as Lens
import qualified Data.Map as Map
import qualified Data.Set as Set
import Hyper.Class.Morph (HMorph (..))
import Hyper.TH.Internal.Utils
import Language.Haskell.TH
import qualified Language.Haskell.TH.Datatype as D

import Hyper.Internal.Prelude

-- | Generate an 'HMorph' instance for the named type.
--
-- The type is analysed with 'makeTypeInfo' and the result is handed to
-- 'makeHMorphForType', which produces the instance declaration.
makeHMorph :: Name -> DecsQ
makeHMorph typeName = do
    info <- makeTypeInfo typeName
    makeHMorphForType info

-- Silences HLint's "Use id" hint, presumably triggered by the @[|\x -> x|]@
-- quote generated for node witnesses in 'makeHMorphForType'.
{-# ANN module "HLint: ignore Use id" #-}

-- | Build the 'HMorph' instance declaration for an already-analysed type.
--
-- Each type parameter that is not pinned (see 'paramSubsts') is renamed with
-- a @0@ suffix in the source type and a @1@ suffix in the destination type.
-- The generated instance contains:
--
-- * a 'MorphConstraint' type instance. It applies the constraint variable to
--   the source- and destination-side versions of every child node type, and
--   requires 'MorphConstraint' for every embedded type. The result is
--   simplified and combined into a single tuple constraint;
-- * a 'MorphWitness' GADT instance with one constructor per child node type
--   (no fields) and one per embedded type (whose single field is the
--   embedded type's own 'MorphWitness');
-- * 'morphMap', with one clause per constructor of the type;
-- * 'morphLiftConstraint', with one clause per witness constructor, or a
--   single empty lambda-case when there are no witnesses.
--
-- No instance context is generated (see the TODO below).
makeHMorphForType :: TypeInfo -> DecsQ
makeHMorphForType info =
    -- TODO: Contexts
    instanceD
        (pure [])
        [t|HMorph $(pure src) $(pure dst)|]
        [ D.tySynInstDCompat
            ''MorphConstraint
            (Just [pure (plainTV constraintVar)])
            (pure <$> [src, dst, VarT constraintVar])
            (toTuple <$> simplifyContext morphConstraint)
        , dataInstD
            (pure [])
            ''MorphWitness
            [pure src, pure dst, [t|_|], [t|_|]]
            Nothing
            (snd <$> Map.elems witnesses)
            []
        , funD 'morphMap (mkMorphCon <$> tiConstructors info)
        , funD 'morphLiftConstraint liftConstraintClauses
        ]
        <&> (: [])
    where
        (s0, s1) = paramSubsts info
        src :: Type
        src = D.applySubstitution s0 (tiInstance info)
        dst :: Type
        dst = D.applySubstitution s1 (tiInstance info)
        constraintVar :: Name
        constraintVar = mkName "constraint"
        contents = childrenTypes info
        -- @appSubsts x t@ applies @x@ to the source-side and then the
        -- destination-side version of @t@.
        appSubsts :: Type -> Type -> Type
        appSubsts x t = x `AppT` D.applySubstitution s0 t `AppT` D.applySubstitution s1 t
        -- Instance constraint before simplification: the constraint variable
        -- on each child node type, then 'MorphConstraint' on each embed.
        morphConstraint :: [Type]
        morphConstraint =
            [appSubsts (VarT constraintVar) c | c <- tcChildren contents ^.. Lens.folded]
                <> [ appSubsts (ConT ''MorphConstraint) e `AppT` VarT constraintVar
                   | e <- tcEmbeds contents ^.. Lens.folded
                   ]
        -- Witness constructor name: the type's prefix plus a readable
        -- rendering of the witnessed type.
        witName :: Type -> Name
        witName x = mkName (witPrefix <> mkNiceTypeName x)
        -- Field-less witness constructors, one per child node type.
        nodeWits :: [(Type, (Name, Q Con))]
        nodeWits =
            [ (x, (n, gadtC [n] [] (pure (appSubsts morphWithNessOf x))))
            | x <- tcChildren contents ^.. Lens.folded
            , let n = witName x
            ]
        -- Witness constructors for embedded types; each wraps the embedded
        -- type's own 'MorphWitness' for the same @a@ and @b@.
        embedWits :: [(Type, (Name, Q Con))]
        embedWits =
            [ ( x
              , ( n
                , gadtC
                    [n]
                    [ bangType
                        (bang noSourceUnpackedness noSourceStrictness)
                        (pure (appSubsts (ConT ''MorphWitness) x `AppT` varA `AppT` varB))
                    ]
                    (pure (morphWithNessOf `AppT` varA `AppT` varB))
                )
              )
            | x <- tcEmbeds contents ^.. Lens.folded
            , let n = witName x
            ]
        witnesses :: Map Type (Name, Q Con)
        witnesses = Map.fromList (nodeWits <> embedWits)
        varA :: Type
        varA = VarT (mkName "a")
        varB :: Type
        varB = VarT (mkName "b")
        witPrefix :: String
        witPrefix = "M_" <> niceName (tiName info) <> "_"
        morphWithNessOf :: Type
        morphWithNessOf = ConT ''MorphWitness `AppT` src `AppT` dst
        liftConstraintClauses :: [Q Clause]
        liftConstraintClauses
            -- No witnesses: the MorphWitness instance has no constructors, so
            -- the method is an empty lambda-case.
            | Map.null witnesses = [clause [] (normalB (lamCaseE [])) []]
            | otherwise =
                [liftNodeConstraint n | (_, (n, _)) <- nodeWits]
                    <> [liftEmbedConstraint n | (_, (n, _)) <- embedWits]
        -- Node witness: ignore the second argument and return @\x -> x@.
        liftNodeConstraint :: Name -> Q Clause
        liftNodeConstraint n = clause [conP n [], wildP] (normalB [|\x -> x|]) []
        -- Embed witness: delegate to the embedded type's morphLiftConstraint.
        liftEmbedConstraint :: Name -> Q Clause
        liftEmbedConstraint n =
            clause
                [conP n [varP varW], varP varProxy]
                (normalB [|morphLiftConstraint $(varE varW) $(varE varProxy)|])
                []
        varW :: Name
        varW = mkName "w"
        varProxy :: Name
        varProxy = mkName "p"
        -- One morphMap clause: bind the mapping function, then match the
        -- constructor and rebuild it (see 'morphCon').
        mkMorphCon con =
            let (p, b) = morphCon 0 witnesses con
            in  clause [varP varF, p] b []

-- | Name of the variable bound to the mapping function in every generated
-- 'morphMap' clause.
--
-- The leading underscore keeps GHC from reporting it as unused in clauses
-- for constructors that have no fields to map.
varF :: Name
varF = mkName "_f"

-- | Generate the pattern and body of a 'morphMap' clause for one constructor.
--
-- The constructor's fields are bound to variables @x\<i\>@, @x\<i+1\>@, … and
-- the body rebuilds the same constructor:
--
-- * plain fields ('Left') are passed through unchanged;
-- * node fields apply the function bound to 'varF' to the node's witness;
-- * container fields are mapped with 'fmap';
-- * generic embeds go through the embedded type's 'morphMap';
-- * flat embeds are rewritten with a lambda-case over the embedded type's
--   constructors, numbering variables from after this constructor's.
--
-- If a node or embedded type has no entry in @witnesses@, the splice fails
-- in 'Q' with a descriptive message rather than a bare lens error.
morphCon :: Int -> Map Type (Name, a) -> (Name, b, [Either c CtrTypePattern]) -> (Q Pat, Q Body)
morphCon i witnesses (n, _, fields) =
    ( conP n (varP <$> cVars)
    , normalB (foldl appE (conE n) (zipWith bodyFor fields cVars))
    )
    where
        cVars :: [Name]
        cVars = take (length fields) [mkName ('x' : show k) | k <- [i ..]]
        f :: Q Exp
        f = varE varF
        -- Constructor expression of the witness for a node or embedded type.
        witnessCon :: Type -> Q Exp
        witnessCon t =
            case Map.lookup t witnesses of
                Just (w, _) -> conE w
                Nothing -> fail ("Hyper.TH.Morph: no MorphWitness constructor for " <> pprint t)
        bodyFor :: Either c CtrTypePattern -> Name -> Q Exp
        bodyFor Left{} v = varE v
        bodyFor (Right x) v = [|$(bodyForPat x) $(varE v)|]
        bodyForPat :: CtrTypePattern -> Q Exp
        bodyForPat (Node x) = [|$f $(witnessCon x)|]
        bodyForPat (InContainer _ pat) = [|fmap $(bodyForPat pat)|]
        bodyForPat (FlatEmbed x) =
            lamCaseE
                [ uncurry match (morphCon (i + length cVars) witnesses con) []
                | con <- tiConstructors x
                ]
        bodyForPat (GenEmbed t) = [|morphMap ($f . $(witnessCon t))|]

-- | Parameter substitutions for the source and destination sides of a morph,
-- in that order (see 'paramSubsts').
type MorphSubsts = (Map Name Type, Map Name Type)

-- | Substitutions that turn the type's instance head into its source and
-- destination versions.
--
-- Each type parameter that is not pinned (see 'pinnedParams') is mapped to a
-- fresh variable named after it, with a @0@ suffix in the first map and a @1@
-- suffix in the second. Pinned parameters appear in neither map, so they stay
-- shared between the two sides.
paramSubsts :: TypeInfo -> MorphSubsts
paramSubsts info = foldMap (mkInfo . D.tvName) (tiParams info)
    where
        pinned :: Set Name
        pinned = pinnedParams info
        mkInfo :: Name -> MorphSubsts
        mkInfo name
            | pinned ^. Lens.contains name = mempty
            | otherwise = (side name "0", side name "1")
        side :: Name -> String -> Map Name Type
        side name suffix = Map.singleton name (VarT (mkName (nameBase name <> suffix)))

-- | Type parameters mentioned by any constructor field outside the positions
-- that 'morphMap' rewrites.
--
-- 'paramSubsts' keeps these parameters identical in the source and
-- destination types.
pinnedParams :: TypeInfo -> Set Name
pinnedParams info =
    foldMap (\(_, _, fields) -> foldMap ctrPinnedParams fields) (tiConstructors info)

-- | Pinned parameters contributed by a single constructor field.
--
-- * A plain field ('Left') pins every type variable it mentions.
-- * Node and generic-embed positions pin nothing, since 'morphMap' maps
--   those fields.
-- * A flat embed contributes the embedded type's pinned parameters.
-- * A container pins the variables of the container type plus whatever its
--   element pattern pins.
ctrPinnedParams :: Either Type CtrTypePattern -> Set Name
ctrPinnedParams (Left t) = typeParams t
ctrPinnedParams (Right pat) =
    case pat of
        Node{} -> mempty
        GenEmbed{} -> mempty
        FlatEmbed info -> pinnedParams info
        InContainer c p -> typeParams c <> ctrPinnedParams (Right p)

-- | Type variables occurring in a type.
--
-- The function recurses through these constructors:
--
-- * type applications;
-- * resolved and unresolved infix applications;
-- * parentheses;
-- * kind signatures, covering both the annotated type and its kind.
--
-- Any other 'Type' constructor contributes no variables.
typeParams :: Type -> Set Name
typeParams (VarT x) = Set.singleton x
typeParams (AppT f x) = typeParams f <> typeParams x
typeParams (InfixT x _ y) = typeParams x <> typeParams y
typeParams (UInfixT x _ y) = typeParams x <> typeParams y
typeParams (ParensT t) = typeParams t
typeParams (SigT t k) = typeParams t <> typeParams k
-- TODO: Remaining cases (e.g. ForallT, AppKindT)
typeParams _ = mempty