{-# LANGUAGE BangPatterns #-}

-- | Unification
module Hyper.Unify
    ( unify
    , module Hyper.Class.Unify
    , module Hyper.Unify.Binding
    , module Hyper.Unify.Constraints
    , module Hyper.Unify.Error
      -- | Exported only for SPECIALIZE pragmas
    , updateConstraints
    , updateTermConstraints
    , updateTermConstraintsH
    , unifyUTerms
    , unifyUnbound
    ) where

import Algebra.PartialOrd (PartialOrd (..))
import Hyper
import Hyper.Class.Unify
import Hyper.Class.ZipMatch (zipMatchA)
import Hyper.Unify.Binding (UVar)
import Hyper.Unify.Constraints
import Hyper.Unify.Error (UnifyError (..))
import Hyper.Unify.Term (UTerm (..), UTermBody (..), uBody, uConstraints)

import Hyper.Internal.Prelude

-- TODO: implement these when needed, or once their motivations are better
-- understood: occursIn, seenAs, getFreeVars, freshen, equals, equiv
-- (from the unification-fd package)

-- | Propagate @newConstraints@ to the unification variable @v@, whose
-- current binding is @x@.
--
-- * 'UUnbound': nothing happens when @newConstraints@ is already `leq` the
--   recorded constraints; otherwise @v@ is rebound to an unbound term that
--   carries the old and new constraints combined with '<>'.
-- * 'USkolem': unless @newConstraints@ is already `leq` the skolem's
--   constraints, this reports 'SkolemEscape'.
-- * 'UTerm': delegates to 'updateTermConstraints'.
-- * 'UResolving': 'updateTermConstraints' marks a variable this way while
--   it pushes constraints into the term's children, so reaching it here is
--   reported via 'occursError'.
-- * Any other binding is unexpected at this stage and calls 'error'.
{-# INLINE updateConstraints #-}
updateConstraints ::
    Unify m t =>
    TypeConstraintsOf t ->
    UVarOf m # t ->
    UTerm (UVarOf m) # t ->
    m ()
updateConstraints !newConstraints v x =
    case x of
        UUnbound l
            | newConstraints `leq` l -> pure ()
            | otherwise ->
                -- Join instead of overwriting, so constraints already
                -- recorded on v are kept when they are incomparable with
                -- newConstraints (unifyUTerms combines constraints with <>
                -- in the same way).
                -- NOTE(review): assumes <> is the join for `leq`; for a
                -- totally ordered constraint type this equals newConstraints.
                bindVar binding v (UUnbound (l <> newConstraints))
        USkolem l
            | newConstraints `leq` l -> pure ()
            | otherwise -> SkolemEscape v & unifyError
        UTerm t -> updateTermConstraints v t newConstraints
        UResolving t -> occursError v t & void
        _ -> error "updateConstraints: This shouldn't happen in unification stage"

-- | Propagate @newConstraints@ to the variable @v@, which is bound to the
-- term body @t@.
--
-- Nothing happens when @newConstraints@ is already `leq` the constraints
-- recorded on @t@. Otherwise the steps are:
--
-- 1. @v@ is temporarily bound to 'UResolving' (so 'updateConstraints'
--    reports it via 'occursError' if it is reached again).
-- 2. The constraints are checked against the body with 'verifyConstraints';
--    a failure is reported as 'ConstraintsViolation'.
-- 3. The resulting per-child constraints are pushed down with
--    'updateTermConstraintsH'.
-- 4. @v@ is rebound to the term, recording the old and new constraints
--    combined with '<>'.
{-# INLINE updateTermConstraints #-}
updateTermConstraints ::
    forall m t.
    Unify m t =>
    UVarOf m # t ->
    UTermBody (UVarOf m) # t ->
    TypeConstraintsOf t ->
    m ()
updateTermConstraints v t newConstraints
    | newConstraints `leq` oldConstraints = pure ()
    | otherwise =
        do
            bindVar binding v (UResolving t)
            case verifyConstraints newConstraints body of
                Nothing -> ConstraintsViolation body newConstraints & unifyError
                Just prop ->
                    do
                        htraverse_ (Proxy @(Unify m) #> updateTermConstraintsH) prop
                        -- Record the join rather than only newConstraints so
                        -- constraints verified earlier stay recorded (the
                        -- same combination unifyUTerms uses).
                        UTermBody (oldConstraints <> newConstraints) body
                            & UTerm
                            & bindVar binding v
                    -- Supplies the Unify m instances for t's child node
                    -- types, required by the htraverse_ above.
                    \\ unifyRecursive (Proxy @m) (Proxy @t)
    where
        oldConstraints = t ^. uConstraints
        body = t ^. uBody

-- | Propagate the constraints paired with a variable in a 'WithConstraint'
-- to that variable.
--
-- The variable is first resolved with 'semiPruneLookup', and the result is
-- handed to 'updateConstraints'. 'updateTermConstraints' uses this to push
-- constraints down to a term's children.
{-# INLINE updateTermConstraintsH #-}
updateTermConstraintsH ::
    Unify m t =>
    WithConstraint (UVarOf m) # t ->
    m ()
updateTermConstraintsH (WithConstraint c v0) =
    do
        (v1, x) <- semiPruneLookup v0
        updateConstraints c v1 x

-- | Unify two unification variables, returning the variable that
-- represents the unified result.
--
-- If the variables are equal, or become equal once resolved with
-- 'semiPruneLookup', the (resolved) left variable is returned directly.
-- Otherwise both resolved variables and their bindings are handed to
-- 'unifyUTerms'.
{-# INLINE unify #-}
unify ::
    forall m t.
    Unify m t =>
    UVarOf m # t ->
    UVarOf m # t ->
    m (UVarOf m # t)
unify x0 y0
    | x0 == y0 = pure x0
    | otherwise =
        do
            (x1, xu) <- semiPruneLookup x0
            if x1 == y0
                then pure x1
                else do
                    -- Only resolve the right-hand side when the cheap
                    -- check above did not already settle it.
                    (y1, yu) <- semiPruneLookup y0
                    if x1 == y1
                        then pure x1
                        else unifyUTerms x1 xu y1 yu

-- | Unify an unbound variable @xv@ (given with its constraints) with the
-- variable @yv@, whose current binding is @yt@.
--
-- The constraints are first propagated to @yv@ with 'updateConstraints'.
-- Then @xv@ is bound to point at @yv@ ('UToVar'), and @yv@ is returned.
{-# INLINE unifyUnbound #-}
unifyUnbound ::
    Unify m t =>
    WithConstraint (UVarOf m) # t ->
    UVarOf m # t ->
    UTerm (UVarOf m) # t ->
    m (UVarOf m # t)
unifyUnbound (WithConstraint level xv) yv yt =
    do
        updateConstraints level yv yt
        yv <$ bindVar binding xv (UToVar yv)

-- | Unify two variables given their current bindings (as obtained by
-- 'semiPruneLookup' in 'unify'), returning the variable that represents
-- the result.
--
-- * If either binding is 'UUnbound', that variable is redirected to the
--   other one via 'unifyUnbound'.
-- * Otherwise, a 'USkolem' on either side reports 'SkolemUnified'.
-- * Two 'UTerm's are handled as follows:
--
--     * @yv@ is redirected to @xv@.
--     * The bodies are matched node by node with 'zipMatchA', and each pair
--       of children is unified recursively. When the shapes differ,
--       'structureMismatch' is called instead and @xt@'s body is kept.
--     * @xv@ is bound to the resulting body, with both sides' constraints
--       combined using '<>'.
--
-- * Any other combination is unexpected at this stage and calls 'error'.
{-# INLINE unifyUTerms #-}
unifyUTerms ::
    forall m t.
    Unify m t =>
    UVarOf m # t ->
    UTerm (UVarOf m) # t ->
    UVarOf m # t ->
    UTerm (UVarOf m) # t ->
    m (UVarOf m # t)
unifyUTerms xv (UUnbound level) yv yt = unifyUnbound (WithConstraint level xv) yv yt
unifyUTerms xv xt yv (UUnbound level) = unifyUnbound (WithConstraint level yv) xv xt
unifyUTerms xv USkolem{} yv _ = xv <$ unifyError (SkolemUnified xv yv)
unifyUTerms xv _ yv USkolem{} = yv <$ unifyError (SkolemUnified yv xv)
unifyUTerms xv (UTerm xt) yv (UTerm yt) =
    do
        -- NOTE(review): yv is redirected before the children are unified,
        -- presumably so a recursive visit to yv resolves to xv — confirm.
        bindVar binding yv (UToVar xv)
        zipMatchA (Proxy @(Unify m) #> unify) xBody yBody
            & fromMaybe (xBody <$ structureMismatch unify xBody yBody)
            >>= bindVar binding xv . UTerm . UTermBody (xt ^. uConstraints <> yt ^. uConstraints)
        pure xv
    -- Supplies the Unify m instances for t's child node types, required by
    -- zipMatchA and structureMismatch above.
    \\ unifyRecursive (Proxy @m) (Proxy @t)
    where
        xBody = xt ^. uBody
        yBody = yt ^. uBody
unifyUTerms _ _ _ _ = error "unifyUTerms: This shouldn't happen in unification stage"