{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell   #-}

module Data.ECTA.Internal.ECTA.Enumeration (
    TermFragment(..)
  , termFragToTruncatedTerm

  , SuspendedConstraint(..)
  , scGetPathTrie
  , scGetUVar
  , descendScs
  , UVarValue(..)

  , EnumerationState(..)
  , uvarCounter
  , uvarRepresentative
  , uvarValues
  , initEnumerationState


  , EnumerateM
  , getUVarRepresentative
  , assimilateUvarVal
  , mergeNodeIntoUVarVal
  , getUVarValue
  , getTermFragForUVar
  , runEnumerateM


  , enumerateNode
  , enumerateEdge
  , firstExpandableUVar
  , enumerateOutUVar
  , enumerateOutFirstExpandableUVar
  , enumerateFully
  , expandTermFrag
  , expandUVar

  , getAllTruncatedTerms
  , getAllTerms
  , naiveDenotation
  ) where

import Control.Monad ( forM_, guard )
import Control.Monad.State.Strict ( StateT(..) )
import qualified Data.IntMap as IntMap
import Data.Maybe ( fromMaybe, isJust )
import Data.Monoid ( Any(..) )
import Data.Semigroup ( Max(..) )
import           Data.Sequence ( Seq((:<|), (:|>)) )
import qualified Data.Sequence as Sequence
import Control.Monad.Identity ( Identity )

import Control.Lens ( use, ix, (%=), (.=) )
import Control.Lens.TH ( makeLenses )
import           Pipes
import qualified Pipes.Prelude as Pipes

import Data.List.Index ( imapM )

import Data.ECTA.Internal.ECTA.Operations
import Data.ECTA.Internal.ECTA.Type
import Data.ECTA.Paths
import Data.ECTA.Term
import           Data.Persistent.UnionFind ( UnionFind, UVar, uvarToInt, intToUVar, UVarGen )
import qualified Data.Persistent.UnionFind as UnionFind
import Data.Text.Extended.Pretty

-------------------------------------------------------------------------------


---------------------------------------------------------------------------
------------------------------- Term fragments ----------------------------
---------------------------------------------------------------------------

-- | A term under construction during enumeration.  Each position is either a
-- concrete symbol applied to child fragments, or a reference to a
-- unification variable ('UVar').
data TermFragment
  = TermFragmentNode !Symbol ![TermFragment]  -- ^ Symbol with its child fragments
  | TermFragmentUVar UVar                     -- ^ Reference to a unification variable
  deriving ( Eq, Ord, Show )

-- | Convert a 'TermFragment' into a 'Term'.  Symbol nodes are converted
-- recursively.  Every unification variable becomes a nullary leaf whose
-- symbol is the text "v" followed by the variable's integer id
-- ('uvarToInt').
termFragToTruncatedTerm :: TermFragment -> Term
termFragToTruncatedTerm frag = case frag of
  TermFragmentNode sym children -> Term sym (termFragToTruncatedTerm <$> children)
  TermFragmentUVar var          -> Term (Symbol ("v" <> pretty (uvarToInt var))) []

---------------------------------------------------------------------------
------------------------------ Enumeration state --------------------------
---------------------------------------------------------------------------

-----------------------
------- Suspended constraints
-----------------------

-- | A suspended constraint: a 'PathTrie' paired with the 'UVar' it refers
-- to.  'descendScs' pushes the trie down one index at a time.
data SuspendedConstraint
  = SuspendedConstraint !PathTrie !UVar
  deriving ( Eq, Ord, Show )

-- | The 'PathTrie' component of a 'SuspendedConstraint'.
scGetPathTrie :: SuspendedConstraint -> PathTrie
scGetPathTrie sc = case sc of
  SuspendedConstraint trie _ -> trie

-- | The 'UVar' component of a 'SuspendedConstraint'.
scGetUVar :: SuspendedConstraint -> UVar
scGetUVar sc = case sc of
  SuspendedConstraint _ var -> var

-- | Descend every suspended constraint's path trie at index @i@ (via
-- 'pathTrieDescend').  Constraints whose trie becomes empty are dropped; the
-- 'UVar' of each surviving constraint is unchanged.
descendScs :: Int -> Seq SuspendedConstraint -> Seq SuspendedConstraint
descendScs i = Sequence.filter stillRelevant . fmap descend
  where
    descend (SuspendedConstraint trie var) = SuspendedConstraint (pathTrieDescend trie i) var
    stillRelevant = not . isEmptyPathTrie . scGetPathTrie

-----------------------
------- UVarValue
-----------------------

-- | The value currently recorded for a unification variable.
data UVarValue
  = UVarUnenumerated
      { contents    :: !(Maybe Node)
        -- ^ Node constraining the variable, if any (intersected on merge).
      , constraints :: !(Seq SuspendedConstraint)
        -- ^ Suspended constraints attached to the variable.
      }
    -- ^ Not yet enumerated.
  | UVarEnumerated { termFragment :: !TermFragment }
    -- ^ Already enumerated into a 'TermFragment'.
  | UVarEliminated
    -- ^ Its value has been merged into another variable by 'assimilateUvarVal'.
  deriving ( Eq, Ord, Show )

-- | Combine two unenumerated values for the same variable.
--
-- * Contents: a missing ('Nothing') side is treated as unconstrained;
--   otherwise the two nodes are combined with 'intersect'.
-- * Constraints: the two sequences are concatenated (first argument's
--   constraints first).
--
-- Calls 'error' if either side is 'UVarEliminated', or if either side is
-- 'UVarEnumerated' (not implemented).
intersectUVarValue :: UVarValue -> UVarValue -> UVarValue
intersectUVarValue (UVarUnenumerated mn1 scs1) (UVarUnenumerated mn2 scs2) =
    UVarUnenumerated (mergeContents mn1 mn2) (scs1 <> scs2)
  where
    mergeContents Nothing    other      = other
    mergeContents other      Nothing    = other
    mergeContents (Just n1)  (Just n2)  = Just (intersect n1 n2)
intersectUVarValue UVarEliminated _ =
    error "intersectUVarValue: Unexpected UVarEliminated"
intersectUVarValue _ UVarEliminated =
    error "intersectUVarValue: Unexpected UVarEliminated"
intersectUVarValue _ _ =
    error "intersectUVarValue: Intersecting with enumerated value not implemented"


-----------------------
------- Top-level state
-----------------------

-- | State threaded through 'EnumerateM'.
data EnumerationState = EnumerationState
  { _uvarCounter        :: UVarGen
    -- ^ Generator for fresh unification variables.
  , _uvarRepresentative :: UnionFind
    -- ^ Union-find structure mapping variables to their representatives.
  , _uvarValues         :: Seq UVarValue
    -- ^ Value of each variable, indexed by 'uvarToInt'.
  }
  deriving ( Eq, Ord, Show )

makeLenses ''EnumerationState


-- | Starting state for enumerating node @n@.  One fresh variable is drawn
-- from the initial generator and registered in the union-find.  Its value
-- (slot 0 of the value table) is @n@ with no suspended constraints.
initEnumerationState :: Node -> EnumerationState
initEnumerationState n =
    EnumerationState { _uvarCounter        = gen
                     , _uvarRepresentative = UnionFind.withInitialValues [root]
                     , _uvarValues         = Sequence.singleton (UVarUnenumerated (Just n) Sequence.Empty)
                     }
  where
    (gen, root) = UnionFind.nextUVar UnionFind.initUVarGen



---------------------------------------------------------------------------
---------------------------- Enumeration monad ----------------------------
---------------------------------------------------------------------------

---------------------
-------- Monad
---------------------


-- | The enumeration monad: 'EnumerationState' layered over the list monad.
-- Each list element is one branch of the search.  'guard' failures prune a
-- branch by producing no results.
type EnumerateM = StateT EnumerationState []

-- | Run an enumeration from the given state, returning every surviving
-- branch's result together with that branch's final state.
runEnumerateM :: EnumerateM a -> EnumerationState -> [(a, EnumerationState)]
runEnumerateM action st = runStateT action st


---------------------
-------- UVar accessors
---------------------

-- | Draw a fresh 'UVar' and advance the generator stored in the state.
nextUVar :: EnumerateM UVar
nextUVar = do
  gen <- use uvarCounter
  let (gen', fresh) = UnionFind.nextUVar gen
  uvarCounter .= gen'
  pure fresh

-- | Allocate a fresh variable.  An unenumerated value for it (contents @x@,
-- no constraints) is appended to the end of the value table.
--
-- NOTE(review): lookups index the table by 'uvarToInt', so this presumably
-- relies on fresh ids being handed out sequentially.  Under that assumption
-- the appended slot's index equals the new variable's id -- confirm against
-- "Data.Persistent.UnionFind".
addUVarValue :: Maybe Node -> EnumerateM UVar
addUVarValue x = do
  fresh <- nextUVar
  uvarValues %= \vals -> vals :|> UVarUnenumerated x Sequence.Empty
  pure fresh

-- | Look up the value stored for a variable's current union-find
-- representative.  Resolving the representative may update the union-find
-- in the state (see 'getUVarRepresentative').
getUVarValue :: UVar -> EnumerateM UVarValue
getUVarValue uv = do
  rep  <- getUVarRepresentative uv
  vals <- use uvarValues
  pure (Sequence.index vals (uvarToInt rep))

-- | The 'TermFragment' that a variable (via its representative) has been
-- enumerated to.
--
-- Calls 'error' with a message naming the offending state if the variable
-- is not 'UVarEnumerated'.  Previously this failed inside the partial
-- record selector 'termFragment' with an uninformative RecSelError.  As
-- before, the check only fires when the returned fragment is forced.
getTermFragForUVar :: UVar -> EnumerateM TermFragment
getTermFragForUVar uv = fragmentOf <$> getUVarValue uv
  where
    fragmentOf (UVarEnumerated frag) = frag
    fragmentOf UVarUnenumerated {}   = error "getTermFragForUVar: variable has not been enumerated yet"
    fragmentOf UVarEliminated        = error "getTermFragForUVar: variable has been eliminated"

-- | Resolve a variable to its union-find representative.  The union-find
-- returned by 'UnionFind.find' is written back into the state.
getUVarRepresentative :: UVar -> EnumerateM UVar
getUVarRepresentative uv = do
  classes <- use uvarRepresentative
  let (rep, classes') = UnionFind.find uv classes
  uvarRepresentative .= classes'
  pure rep

---------------------
-------- Creating UVar's
---------------------

-- | Build a 'SuspendedConstraint' from a 'PathEClass'.  The constraint pairs
-- the class's path trie with a freshly allocated variable that has no
-- contents.
pecToSuspendedConstraint :: PathEClass -> EnumerateM SuspendedConstraint
pecToSuspendedConstraint pec =
  SuspendedConstraint (getPathTrie pec) <$> addUVarValue Nothing


---------------------
-------- Merging UVar's / nodes
---------------------

-- | Merge the value of @uvSrc@ into @uvTarg@.
--
-- The two values are combined with 'intersectUVarValue' and the result is
-- stored at @uvTarg@; @uvSrc@'s slot is then marked 'UVarEliminated'.  The
-- branch is pruned (via 'guard') if the combined contents are 'EmptyNode'.
-- The call is a no-op when both variables are equal or when @uvSrc@ is
-- already eliminated.  Only the value table is updated; the union-find is
-- not touched here.
--
-- NOTE(review): the table is indexed by the arguments directly, without
-- resolving representatives, so callers presumably pass representatives.
-- If @uvTarg@ itself were eliminated, 'intersectUVarValue' would call
-- 'error' -- confirm callers rule this out.
assimilateUvarVal :: UVar -> UVar -> EnumerateM ()
assimilateUvarVal uvTarg uvSrc
  | uvTarg == uvSrc = return ()
  | otherwise       = do
      vals <- use uvarValues
      let srcIdx  = uvarToInt uvSrc
          targIdx = uvarToInt uvTarg
          srcVal  = Sequence.index vals srcIdx
      case srcVal of
        -- Happens from duplicate constraints
        UVarEliminated -> return ()
        _ -> do
          let merged = intersectUVarValue srcVal (Sequence.index vals targIdx)
          guard (contents merged /= Just EmptyNode)
          uvarValues . ix targIdx .= merged
          uvarValues . ix srcIdx  .= UVarEliminated


-- | Intersect node @n@, together with suspended constraints @scs@, into the
-- value of @uv@'s representative.  The current branch is pruned (via
-- 'guard') if the resulting contents are 'EmptyNode'.
mergeNodeIntoUVarVal :: UVar -> Node -> Seq SuspendedConstraint -> EnumerateM ()
mergeNodeIntoUVarVal uv n scs = do
  slot <- uvarToInt <$> getUVarRepresentative uv
  -- The incoming value is the first argument to 'intersectUVarValue', so the
  -- result is @intersect n existing@ and @scs <> existingConstraints@.
  uvarValues . ix slot %= intersectUVarValue (UVarUnenumerated (Just n) scs)
  merged <- flip Sequence.index slot <$> use uvarValues
  guard (contents merged /= Just EmptyNode)


---------------------
-------- Variant maintainer
---------------------

-- | Rewrite every suspended constraint stored in an unenumerated uvar value so
-- that it refers to the current union-find representative of its uvar.
-- Values of any other shape are written back unchanged.
--
-- Performance note: this pass has been measured at roughly 1/3 of the time and
-- 1/2 of the allocations of 'enumerateFully'. It exists because it was easier to
-- write (and might even be faster) than updating referenced uvars inline in
-- 'firstExpandableUVar'; there is no @Sequence.foldMapWithIndexM@.
refreshReferencedUVars :: EnumerateM ()
refreshReferencedUVars = do
    values  <- use uvarValues
    updated <- traverse refreshValue values
    uvarValues .= updated
  where
    -- Only unenumerated values carry suspended constraints that can go stale.
    refreshValue :: UVarValue -> EnumerateM UVarValue
    refreshValue (UVarUnenumerated n scs) = UVarUnenumerated n <$> traverse refreshConstraint scs
    refreshValue other                    = return other

    -- Keep the path trie, swap the uvar for its current representative.
    refreshConstraint :: SuspendedConstraint -> EnumerateM SuspendedConstraint
    refreshConstraint sc = SuspendedConstraint (scGetPathTrie sc) <$> getUVarRepresentative (scGetUVar sc)


---------------------
-------- Core enumeration algorithm
---------------------

-- | Enumerate the term fragments for node @n@ under the suspended constraints @scs@.
--
-- Constraints whose path trie is terminal ('isTerminalPathTrie') apply to @n@
-- itself. When there are any, all of their uvars are unioned into one, the
-- values of their previous representatives are assimilated into it, @n@ is
-- merged into its value together with the remaining (descendant) constraints,
-- and the result is a reference to that uvar.
--
-- With no such constraints, a 'Mu' node is suspended behind a uvar obtained from
-- 'addUVarValue', and a plain 'Node' tries each of its edges nondeterministically.
-- 'EmptyNode' yields no results; any other node kind is an internal error.
enumerateNode :: Seq SuspendedConstraint -> Node -> EnumerateM TermFragment
enumerateNode _ EmptyNode = mzero
enumerateNode scs n =
    case Sequence.partition (isTerminalPathTrie . scGetPathTrie) scs of
      (Sequence.Empty, _) ->
        -- Nothing constrains this node directly; scs are all descendant constraints.
        case n of
          Mu _    -> TermFragmentUVar <$> addUVarValue (Just n)
          Node es -> lift es >>= enumerateEdge scs
          _       -> error $ "enumerateNode: unexpected node " <> show n

      (here@(anchor :<| rest), below) -> do
        -- Capture the representatives before unioning so their values can be folded in.
        reps <- traverse (getUVarRepresentative . scGetUVar) here
        forM_ rest $ \sc ->
          uvarRepresentative %= UnionFind.union (scGetUVar anchor) (scGetUVar sc)
        uv <- getUVarRepresentative (scGetUVar anchor)
        mapM_ (assimilateUvarVal uv) reps
        mergeNodeIntoUVarVal uv n below
        return (TermFragmentUVar uv)

-- | Enumerate the term fragments rooted at edge @e@.
--
-- The edge is pruned when some suspended constraint in @scs@ reaches a child
-- index (as reported by 'getMaxNonemptyIndex') that @e@ does not have.
-- Otherwise the edge's own equality classes are converted with
-- 'pecToSuspendedConstraint' and appended to @scs@. Each child is then
-- enumerated with those constraints descended into its index ('descendScs').
enumerateEdge :: Seq SuspendedConstraint -> Edge -> EnumerateM TermFragment
enumerateEdge scs e = do
    let children = edgeChildren e
        -- A constraint with no non-empty index counts as -1; an empty scs gives
        -- minBound (the unit of Max Int), so it never prunes.
        deepestIndex :: Int
        deepestIndex = getMax (foldMap (Max . fromMaybe (-1) . getMaxNonemptyIndex . scGetPathTrie) scs)
    guard (deepestIndex < length children)

    edgeScs <- traverse pecToSuspendedConstraint (unsafeGetEclasses (edgeEcs e))
    let allScs = scs <> Sequence.fromList edgeScs
    TermFragmentNode (edgeSymbol e)
      <$> imapM (\i child -> enumerateNode (descendScs i allScs) child) children


---------------------
-------- Enumeration-loop control
---------------------

-- | Outcome of searching for the next uvar to enumerate out ('firstExpandableUVar').
data ExpandableUVarResult
  = ExpansionStuck      -- ^ candidates exist, but every one is referenced by a suspended constraint
  | ExpansionDone       -- ^ there are no candidate uvars left
  | ExpansionNext !UVar -- ^ this uvar can be enumerated out next

-- | Choose the next uvar to enumerate out.
--
-- Candidates are representative uvars whose value is unenumerated with a known
-- node, except 'Mu' nodes that carry no suspended constraints. A candidate is
-- ruled out when any uvar value's suspended constraints reference it.
--
-- Returns 'ExpansionDone' if there are no candidates, 'ExpansionStuck' if every
-- candidate is ruled out, and otherwise 'ExpansionNext' with the smallest
-- remaining candidate.
--
-- Can speed this up with bitvectors.
firstExpandableUVar :: EnumerateM ExpandableUVarResult
firstExpandableUVar = do
    values <- use uvarValues
    -- Check representative uvars, because only representatives are updated.
    candidateMaps <- traverse candidateAt [0 .. Sequence.length values - 1]
    let candidates = IntMap.unions candidateMaps
    if IntMap.null candidates
      then return ExpansionDone
      else do
        -- NOTE(review): this scans every entry of uvarValues, not only representatives.
        let ruledOut      = foldMap referencedBy values
            unconstrained = IntMap.difference candidates ruledOut
        return $ case IntMap.lookupMin unconstrained of
          Nothing     -> ExpansionStuck
          Just (i, _) -> ExpansionNext (intToUVar i)
  where
    candidateAt :: Int -> EnumerateM (IntMap.IntMap Any)
    candidateAt i = do
      rep <- getUVarRepresentative (intToUVar i)
      v   <- getUVarValue rep
      return $ case v of
        UVarUnenumerated (Just (Mu _)) Sequence.Empty -> IntMap.empty
        UVarUnenumerated (Just _)      _              -> IntMap.singleton (uvarToInt rep) (Any False)
        _                                             -> IntMap.empty

    referencedBy :: UVarValue -> IntMap.IntMap Any
    referencedBy (UVarUnenumerated _ scs) =
      foldMap (\sc -> IntMap.singleton (uvarToInt (scGetUVar sc)) (Any True)) scs
    referencedBy _                        = IntMap.empty



-- | Enumerate the node stored in @uv@ and record the result.
--
-- A 'Mu' node is unfolded one level ('unfoldOuterRec') before enumeration. The
-- fragment is stored as the 'UVarEnumerated' value of @uv@'s representative,
-- looked up before enumeration. Constraint references are then refreshed
-- ('refreshReferencedUVars') and the fragment is returned.
--
-- If @uv@'s value is not @UVarUnenumerated (Just _) _@, the failed pattern bind
-- prunes this branch, so there are no results.
enumerateOutUVar :: UVar -> EnumerateM TermFragment
enumerateOutUVar uv = do
    UVarUnenumerated (Just node) constraints <- getUVarValue uv
    rep <- getUVarRepresentative uv

    let target = case node of
                   Mu _ -> unfoldOuterRec node
                   _    -> node
    frag <- enumerateNode constraints target

    uvarValues . ix (uvarToInt rep) .= UVarEnumerated frag
    refreshReferencedUVars
    return frag

-- | Enumerate out the uvar picked by 'firstExpandableUVar', discarding the fragment.
-- Yields no results when the search reports 'ExpansionDone' or 'ExpansionStuck'.
enumerateOutFirstExpandableUVar :: EnumerateM ()
enumerateOutFirstExpandableUVar = firstExpandableUVar >>= \case
    ExpansionNext uv -> void (enumerateOutUVar uv)
    _                -> mzero

-- | Repeatedly enumerate out expandable uvars until none remain.
--
-- Succeeds when 'firstExpandableUVar' reports 'ExpansionDone' and fails (no
-- results) when it reports 'ExpansionStuck'. The loop also stops, successfully,
-- if the chosen uvar holds a 'Mu' node with no suspended constraints.
enumerateFully :: EnumerateM ()
enumerateFully = firstExpandableUVar >>= \case
    ExpansionStuck   -> mzero
    ExpansionDone    -> return ()
    ExpansionNext uv -> do
      UVarUnenumerated (Just n) scs <- getUVarValue uv
      let isMu = case n of
                   Mu _ -> True
                   _    -> False
      if null scs && isMu
        then return ()
        else enumerateOutUVar uv >> enumerateFully

---------------------
-------- Expanding an enumerated term fragment into a term
---------------------

-- | Turn a term fragment into a 'Term', following enumerated uvars.
--
-- An unenumerated uvar holding a 'Mu' node becomes the placeholder
-- @Term "Mu" []@. Any other unenumerated uvar is an internal error.
expandTermFrag :: TermFragment -> EnumerateM Term
expandTermFrag (TermFragmentNode s ts) = Term s <$> traverse expandTermFrag ts
expandTermFrag (TermFragmentUVar uv)   = getUVarValue uv >>= \case
    UVarEnumerated t                 -> expandTermFrag t
    UVarUnenumerated (Just (Mu _)) _ -> return (Term "Mu" [])
    _                                -> error "expandTermFrag: Non-recursive, unenumerated node encountered"

-- | Expand the enumerated value of @uv@ into a 'Term'.
--
-- Yields no results if @uv@ is not enumerated. This includes an unenumerated
-- 'Mu' value, even though 'expandTermFrag' would render one as a placeholder.
-- NOTE(review): confirm that dropping a top-level unenumerated Mu is intended.
expandUVar :: UVar -> EnumerateM Term
expandUVar uv = getUVarValue uv >>= \case
    UVarEnumerated t -> expandTermFrag t
    _                -> mzero


---------------------
-------- Full enumeration
---------------------

-- | Enumerate @n@ fully from 'initEnumerationState' and return every result as a
-- truncated term ('termFragToTruncatedTerm' of the fragment for uvar 0,
-- presumably the root uvar set up by 'initEnumerationState').
getAllTruncatedTerms :: Node -> [Term]
getAllTruncatedTerms n =
    map (termFragToTruncatedTerm . fst) (runEnumerateM enumeration (initEnumerationState n))
  where
    enumeration = do
      enumerateFully
      getTermFragForUVar (intToUVar 0)

-- | Enumerate @n@ fully from 'initEnumerationState' and expand uvar 0
-- (presumably the root uvar) of every successful run into a 'Term'.
getAllTerms :: Node -> [Term]
getAllTerms n = fst <$> runEnumerateM enumeration (initEnumerationState n)
  where
    enumeration = do
      enumerateFully
      expandUVar (intToUVar 0)


-- | Inefficient enumeration
--
-- For ECTAs with 'Mu' nodes, this may produce an infinite list or may loop indefinitely, depending on the ECTA. For
-- example, for
--
-- > createMu $ \r -> Node [Edge "f" [r], Edge "a" []]
--
-- it will produce
--
-- > [ Term "a" []
-- > , Term "f" [Term "a" []]
-- > , Term "f" [Term "f" [Term "a" []]]
-- > , ...
-- > ]
--
-- This happens to work currently because non-recursive edges are interned before recursive edges.
--
-- TODO: It would be much nicer if this did fair enumeration. It would avoid the aforementioned dependency on interning
-- order, and it would give better enumeration for examples such as
--
-- > Node [Edge "h" [
-- >     createMu $ \r -> Node [Edge "f" [r], Edge "a" []]
-- >   , createMu $ \r -> Node [Edge "g" [r], Edge "b" []]
-- >   ]]
--
-- This will currently produce
--
-- > [ Term "h" [Term "a" [], Term "b" []]
-- > , Term "h" [Term "a" [], Term "g" [Term "b" []]]
-- > , Term "h" [Term "a" [], Term "g" [Term "g" [Term "b" []]]]
-- > , ...
-- > ]
--
-- where it always unfolds the /second/ argument to @h@, never the first.
--
-- This is @naiveDenotationBounded Nothing@: 'Mu' unfolding is unbounded.
naiveDenotation :: Node -> [Term]
naiveDenotation = naiveDenotationBounded Nothing

-- | Like 'naiveDenotation', but with an optional bound on the depth of 'Mu' unfolding.
--
-- With @Just n@, at most @n@ 'Mu' unfoldings are performed along any path from the
-- root; a 'Mu' node reached with no budget left (or with a non-positive bound)
-- contributes no terms. With @Nothing@, 'Mu' nodes are always unfolded, so the
-- result may be infinite or the enumeration may not terminate.
naiveDenotationBounded :: Maybe Int -> Node -> [Term]
naiveDenotationBounded maxDepth node = Pipes.toList $ every (go maxDepth node)
  where
    -- Note that this code uses the decision that f(a,a) does not satisfy the constraint 0.0=1.0 because those paths
    -- are empty. It would be equally valid to say that it does.
    --
    -- Every equivalence class must hold: its first path must exist in the term, and every other path in the class
    -- must lead to the same subterm. An empty class imposes nothing and is treated as satisfied (rather than
    -- crashing on 'head').
    ecsSatisfied :: Term -> EqConstraints -> Bool
    ecsSatisfied t ecs = all eclassSatisfied (map unPathEClass $ unsafeGetEclasses ecs)
      where
        eclassSatisfied :: [Path] -> Bool
        eclassSatisfied []           = True
        eclassSatisfied (p : others) =
          -- Look up the first path once instead of once per comparison.
          let target = getPath p t
          in isJust target && all (\p' -> target == getPath p' t) others

    go :: Maybe Int -> Node -> ListT Identity Term
    go _       EmptyNode = mzero
    go mbDepth n@(Mu _)  = case mbDepth of
                             Nothing            -> go Nothing (unfoldOuterRec n)
                             Just d | d <= 0    -> mzero
                                    | otherwise -> go (Just $ d - 1) (unfoldOuterRec n)
    go _       (Rec _)   = error "naiveDenotation: unexpected Rec"
    go mbDepth (Node es) = do
      e <- Select $ each es

      -- The remaining depth budget is shared by all children of this edge.
      children <- mapM (go mbDepth) (edgeChildren e)

      let res = Term (edgeSymbol e) children
      guard $ ecsSatisfied res (edgeEcs e)
      return res