{-# LANGUAGE FlexibleContexts #-}

-- | Alpha-equality for schemes
module Hyper.Syntax.Scheme.AlphaEq
    ( alphaEq
    ) where

import Control.Lens (ix)
import Hyper
import Hyper.Class.Optic (HNodeLens (..))
import Hyper.Class.ZipMatch (zipMatch_)
import Hyper.Recurse (wrapM, (#>>))
import Hyper.Syntax.Scheme
import Hyper.Unify
import Hyper.Unify.New (newTerm)
import Hyper.Unify.QuantifiedVar
import Hyper.Unify.Term (UTerm (..), uBody)

import Hyper.Internal.Prelude

-- | Instantiate every quantified variable of a scheme as a fresh skolem.
--
-- 'QVars' maps each quantified variable to its declared constraints. For
-- each entry, the declared constraints are combined (on the left of @<>@)
-- with the current 'scopeConstraints' for @typ@, and a new 'USkolem'
-- unification variable carrying the result is allocated. The returned
-- 'QVarInstances' maps each quantified variable to its skolem.
makeQVarInstancesInScope ::
    forall m typ.
    UnifyGen m typ =>
    QVars # typ ->
    m (QVarInstances (UVarOf m) # typ)
makeQVarInstancesInScope (QVars foralls) =
    traverse makeSkolem foralls <&> QVarInstances
    where
        -- Allocate the skolem for one quantified variable whose declared
        -- constraints are @c@.
        makeSkolem :: TypeConstraintsOf typ -> m (UVarOf m # typ)
        makeSkolem c = scopeConstraints (Proxy @typ) >>= newVar binding . USkolem . (c <>)

-- | Convert a single node of a scheme body into a unification variable.
--
-- If the node is a quantified variable that has an instance in @foralls@
-- (as produced by 'makeQVarInstancesInScope'), that existing skolem
-- variable is returned. Any other node, including a quantified variable
-- with no entry in @foralls@, is allocated as a fresh term via 'newTerm'.
schemeBodyToType ::
    (UnifyGen m typ, HNodeLens varTypes typ, Ord (QVar typ)) =>
    varTypes # QVarInstances (UVarOf m) ->
    typ # UVarOf m ->
    m (UVarOf m # typ)
schemeBodyToType foralls x =
    case x ^? quantifiedVar >>= getForAll of
        Nothing -> newTerm x
        Just r -> pure r
    where
        -- Look up the skolem assigned to quantified variable @v@ in the
        -- @typ@-kinded component of @foralls@.
        getForAll v = foralls ^? hNodeLens . _QVarInstances . ix v

-- | Convert a pure scheme into a unification type in which the scheme's
-- quantified variables are skolem ('USkolem') variables.
--
-- First, every quantified variable (for each node type in @varTypes@) is
-- turned into a skolem with 'makeQVarInstancesInScope'. Then the body is
-- rebuilt with 'wrapM', which hands each node to 'schemeBodyToType' after
-- its children have already been converted.
schemeToRestrictedType ::
    forall m varTypes typ.
    ( HTraversable varTypes
    , HNodesConstraint varTypes (UnifyGen m)
    , HasScheme varTypes m typ
    ) =>
    Pure # Scheme varTypes typ ->
    m (UVarOf m # typ)
schemeToRestrictedType (Pure (Scheme vars typ)) =
    do
        foralls <- htraverse (Proxy @(UnifyGen m) #> makeQVarInstancesInScope) vars
        wrapM (Proxy @(HasScheme varTypes m) #>> schemeBodyToType foralls) typ

-- | Compare two unification variables, given their current terms, for
-- alpha-equivalence. Fails via 'unifyError' when they are not equivalent,
-- and binds variables in the unification state as it goes.
--
-- * Two unpaired skolems are paired: each is rebound to 'UInstantiated'
--   pointing at the other. A later meeting of two paired skolems must
--   reproduce the same pairing, or 'SkolemEscape' is raised.
-- * An unbound variable meeting a skolem (paired or not) is bound to a
--   'UToVar' pointing at it.
-- * 'UToVar' indirections are followed with 'lookupVar'.
-- * A skolem (paired or not) meeting anything else raises 'SkolemUnified'.
-- * Two 'UTerm's are compared child-by-child with 'zipMatch_'. If their
--   shapes differ, the comparison is delegated to 'structureMismatch'.
goUTerm ::
    forall m t.
    Unify m t =>
    UVarOf m # t ->
    UTerm (UVarOf m) # t ->
    UVarOf m # t ->
    UTerm (UVarOf m) # t ->
    m ()
goUTerm xv USkolem{} yv USkolem{} =
    do
        -- First meeting: record the pairing in both directions.
        bindVar binding xv (UInstantiated yv)
        bindVar binding yv (UInstantiated xv)
goUTerm xv (UInstantiated xt) yv (UInstantiated yt)
    -- Both already paired: they must be paired with each other.
    | xv == yt && yv == xt = pure ()
    | otherwise = unifyError (SkolemEscape xv)
goUTerm xv USkolem{} yv UUnbound{} = bindVar binding yv (UToVar xv)
goUTerm xv UUnbound{} yv USkolem{} = bindVar binding xv (UToVar yv)
goUTerm xv UInstantiated{} yv UUnbound{} = bindVar binding yv (UToVar xv)
goUTerm xv UUnbound{} yv UInstantiated{} = bindVar binding xv (UToVar yv)
goUTerm _ (UToVar xv) yv yu =
    do
        xu <- lookupVar binding xv
        goUTerm xv xu yv yu
goUTerm xv xu _ (UToVar yv) =
    do
        yu <- lookupVar binding yv
        goUTerm xv xu yv yu
goUTerm xv USkolem{} yv _ = unifyError (SkolemUnified xv yv)
goUTerm xv _ yv USkolem{} = unifyError (SkolemUnified yv xv)
goUTerm xv UInstantiated{} yv _ = unifyError (SkolemUnified xv yv)
goUTerm xv _ yv UInstantiated{} = unifyError (SkolemUnified yv xv)
-- NOTE(review): if both sides are 'UUnbound', the next clause calls itself
-- with identical arguments and would not terminate; presumably that state
-- is unreachable here -- confirm against 'structureMismatch' implementations.
goUTerm xv UUnbound{} yv yu = goUTerm xv yu yv yu -- Term created in structure mismatch
goUTerm xv xu yv UUnbound{} = goUTerm xv xu yv xu -- Term created in structure mismatch
goUTerm _ (UTerm xt) _ (UTerm yt) =
    -- 'unifyRecursive' supplies the 'Unify' dictionaries for child node
    -- types that the '#>' and 'structureMismatch' calls need.
    zipMatch_ (Proxy @(Unify m) #> goUVar) (xt ^. uBody) (yt ^. uBody)
        & fromMaybe (structureMismatch (\x y -> x <$ goUVar x y) (xt ^. uBody) (yt ^. uBody))
        \\ unifyRecursive (Proxy @m) (Proxy @t)
goUTerm _ _ _ _ = error "unexpected state at alpha-eq"

-- | Compare two unification variables for alpha-equivalence by looking up
-- their current terms and dispatching to 'goUTerm'.
goUVar ::
    Unify m t =>
    UVarOf m # t ->
    UVarOf m # t ->
    m ()
goUVar xv yv =
    do
        xu <- lookupVar binding xv
        yu <- lookupVar binding yv
        goUTerm xv xu yv yu

-- | Check two schemes for alpha-equivalence, i.e. equality up to a
-- consistent renaming of their quantified variables.
--
-- Both schemes are converted with 'schemeToRestrictedType', so their
-- quantified variables become skolems. The resulting types are then
-- compared with 'goUVar'. Raises a 'unifyError' on mismatch.
alphaEq ::
    ( HTraversable varTypes
    , HNodesConstraint varTypes (UnifyGen m)
    , HasScheme varTypes m typ
    ) =>
    Pure # Scheme varTypes typ ->
    Pure # Scheme varTypes typ ->
    m ()
alphaEq s0 s1 =
    do
        t0 <- schemeToRestrictedType s0
        t1 <- schemeToRestrictedType s1
        goUVar t0 t1