-- | Occurs check (check whether unification terms recursively contain themselves)
module Hyper.Unify.Occurs
    ( occursCheck
    ) where

import Control.Monad (unless, when)
import Control.Monad.Trans.Class (MonadTrans (..))
import Control.Monad.Trans.State (execStateT, get, put)
import Hyper
import Hyper.Class.Unify (BindingDict (..), UVarOf, Unify (..), occursError, semiPruneLookup)
import Hyper.Unify.Term (UTerm (..), uBody)

import Hyper.Internal.Prelude

-- | Occurs check: fail if the term bound to a unification variable
-- (recursively) refers back to that same variable.
--
-- The variable is first looked up with 'semiPruneLookup', which yields
-- the variable @v1@ that was actually found and its binding. Then,
-- depending on the binding:
--
-- * 'UResolving': the variable is currently marked as being resolved
--   (for example by an enclosing 'occursCheck' on the same path), so a
--   cycle was found and 'occursError' is raised for it.
--
-- * 'UResolved', 'UUnbound', 'USkolem': nothing to check.
--
-- * 'UTerm': each child of the term body is checked recursively. Before
--   the first child is visited, the variable is rebound to 'UResolving',
--   so that reaching it again from a descendant is reported as a cycle.
--   After all children are checked, it is rebound back to the original
--   'UTerm'. A term with no children is never rebound.
--
-- * 'UToVar', 'UConverted', 'UInstantiated': these states are not expected
--   here and result in a call to 'error'.
{-# INLINE occursCheck #-}
occursCheck ::
    forall m t.
    Unify m t =>
    UVarOf m # t ->
    m ()
occursCheck v0 =
    do
        (v1, x) <- semiPruneLookup v0
        case x of
            UResolving t -> occursError v1 t
            UResolved{} -> pure ()
            UUnbound{} -> pure ()
            USkolem{} -> pure ()
            UTerm b ->
                htraverse_
                    ( Proxy @(Unify m) #>
                        \c ->
                            do
                                -- The Bool state records whether v1 has already
                                -- been marked as resolving; mark it only once,
                                -- just before the first child is checked.
                                get >>= lift . (`unless` bindVar binding v1 (UResolving b))
                                put True
                                occursCheck c & lift
                    )
                    (b ^. uBody)
                    & (`execStateT` False)
                    -- If v1 was marked above, restore its original term binding.
                    >>= (`when` bindVar binding v1 (UTerm b))
                    -- unifyRecursive supplies the 'Unify m' constraint for the
                    -- children visited by 'htraverse_'.
                    \\ unifyRecursive (Proxy @m) (Proxy @t)
            UToVar{} -> error "lookup not expected to result in var (in occursCheck)"
            UConverted{} -> error "conversion state not expected in occursCheck"
            UInstantiated{} -> error "occursCheck during instantiation"