-- | Occurs check (check whether unification terms recursively contain themselves)

module Hyper.Unify.Occurs
    ( occursCheck
    ) where

import Control.Monad (unless, when)
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.State (execStateT, get, put)
import Hyper
import Hyper.Class.Unify (Unify(..), UVarOf, BindingDict(..), semiPruneLookup, occursError)
import Hyper.Unify.Term (UTerm(..), uBody)

import Hyper.Internal.Prelude

-- | Occurs check: fail via 'occursError' if the term bound to a
-- unification variable (transitively) refers back to that same variable.
--
-- The variable is first looked up with 'semiPruneLookup', which yields a
-- variable @v1@ together with its binding @x@. Then:
--
-- * 'UResolving': @v1@ is already being traversed higher up in this same
--   check, so a cycle has been found and 'occursError' is raised.
-- * 'UResolved', 'UUnbound', 'USkolem': nothing to check.
-- * 'UTerm': @v1@ is temporarily re-bound to 'UResolving' while its child
--   variables are checked recursively, and then restored to its original
--   'UTerm' binding.
-- * 'UToVar', 'UConverted', 'UInstantiated': treated as invariant
--   violations and reported with 'error'. Callers must not run the occurs
--   check in those states.
{-# INLINE occursCheck #-}
occursCheck ::
    forall m t.
    Unify m t =>
    UVarOf m # t -> m ()
occursCheck v0 =
    semiPruneLookup v0
    >>=
    \(v1, x) ->
    case x of
    UResolving t -> occursError v1 t
    UResolved{} -> pure ()
    UUnbound{} -> pure ()
    USkolem{} -> pure ()
    UTerm b ->
        -- Bring 'Unify m' into scope for every child node type of t.
        withDict (unifyRecursive (Proxy @m) (Proxy @t)) $
        htraverse_
        ( Proxy @(Unify m) #>
            \c ->
            do
                -- The Bool state records whether v1 has been marked yet.
                -- v1 is marked as 'UResolving' only once, right before
                -- visiting its first child, so leaf terms (no children)
                -- never have their binding rewritten.
                get >>= lift . (`unless` bindVar binding v1 (UResolving b))
                put True
                occursCheck c & lift
        ) (b ^. uBody)
        & (`execStateT` False)
        -- Restore v1's original binding only if it was marked above.
        >>= (`when` bindVar binding v1 (UTerm b))
    UToVar{} -> error "lookup not expected to result in var (in occursCheck)"
    UConverted{} -> error "conversion state not expected in occursCheck"
    UInstantiated{} -> error "occursCheck during instantiation"