{-# LANGUAGE CPP               #-}
{-# LANGUAGE OverloadedStrings #-}

-- For the 'Pathable' instance for 'Node'
{-# OPTIONS_GHC -Wno-orphans #-}

module Data.ECTA.Internal.ECTA.Operations (
  -- * Traversal
    pathsMatching
  , mapNodes
  , crush
  , onNormalNodes

  -- * Unfolding
  , unfoldOuterRec
  , refold
  , nodeEdges
  , unfoldBounded

  -- * Size operations
  , nodeCount
  , edgeCount
  , maxIndegree

  -- * Union
  , union

  -- * Membership
  , nodeRepresents
  , edgeRepresents

  -- * Intersection
  , intersect
  , dropRedundantEdges
  , intersectEdge

  -- * Path operations
  , requirePath
  , requirePathList

  -- * Reduction
  , withoutRedundantEdges
  , reducePartially
  , reduceEdgeIntersection
  , reduceEqConstraints

  -- * Debugging
  , getSubnodeById
  ) where


import Control.Monad.State.Strict ( evalState, State, MonadState(..), modify' )
import Data.Hashable ( hash, Hashable(..) )
import qualified Data.HashMap.Strict as HashMap
import Data.List ( inits, tails )
import Data.Maybe ( catMaybes )
import Data.Monoid ( Sum(..), First(..) )
import Data.Semigroup ( Max(..) )
import           Data.Map.Strict ( Map )
import qualified Data.Map.Strict as Map
import           Data.Set ( Set )
import qualified Data.Set as Set

import Control.Lens ( (&), ix, (^?), (%~) )
import Data.List.Index ( imap )

import Data.ECTA.Internal.ECTA.Type
import Data.ECTA.Internal.Paths
import Data.ECTA.Internal.Term

--   Switch the comments on these lines to switch to ekmett's original `intern` library
--   instead of our single-threaded hashtable-based reimplementation.
import Data.Interned.Extended.HashTableBased ( Id )
-- import Data.Interned ( Interned(..), unintern, Id, Cache, mkCache )
-- import Data.Interned.Extended.SingleThreaded ( intern )

import Data.Memoization ( MemoCacheTag(..), memo, memo2 )
import Utility.Fixpoint
import Utility.HashJoin

------------------------------------------------------------------------------------


-----------------------
------ Traversal
-----------------------

-- | All paths from the given node down to normal ('Node') nodes that satisfy
--   the predicate, including 'EmptyPath' when the root itself satisfies it.
--   'Mu' nodes are not entered, and a bare 'Rec' is an error.
--
--   Warning: Linear in number of paths, exponential in size of graph.
--   Only use for very small graphs.
pathsMatching :: (Node -> Bool) -> Node -> [Path]
pathsMatching _ EmptyNode   = []
pathsMatching _ (Mu _)      = [] -- Unsound: paths beneath a Mu are never explored
pathsMatching p n@(Node es) = concatMap throughEdge es ++ [EmptyPath | p n]
  where
    -- Prefix every matching path of each child with that child's index in the edge.
    throughEdge :: Edge -> [Path]
    throughEdge (Edge _ kids) =
      concat (imap (\idx kid -> ConsPath idx <$> pathsMatching p kid) kids)
pathsMatching _ (Rec _)     = error "pathsMatching: unexpected Rec"

-- | Rebuild a node, rewriting children first and then applying @f@ to every
--   rebuilt normal, 'Mu', and 'Rec' node. 'EmptyNode' is returned as-is.
--
--   Precondition: For all i, f (Rec i) is either a Rec node meant to represent
--                 the enclosing Mu, or contains no Rec node not beneath another Mu.
mapNodes :: (Node -> Node) -> Node -> Node
mapNodes f = rewrite
  where
    -- Memoized separately for each mapNodes invocation
    rewrite :: Node -> Node
    rewrite = memo (NameTag "mapNodes") step
    {-# NOINLINE rewrite #-}

    step :: Node -> Node
    step node = case node of
      EmptyNode -> EmptyNode
      Node es   -> f (Node [setChildren e (rewrite <$> edgeChildren e) | e <- es])
      Mu body   -> f (Mu (rewrite . body))
      Rec i     -> f (Rec i)

-- | Fold a monoidal summary over every node reachable from the root.
--
-- This name originates from the "crush" operator in the Stratego language. C.f.: the "crushtdT"
-- combinators in the KURE and compstrat libraries.
--
-- Each interned normal node is visited at most once (tracked by 'nodeIdentity' in a 'Set');
-- 'EmptyNode' and 'Rec' contribute 'mempty'. Although m is only constrained to be a monoid,
-- crush makes no guarantees about ordering.
crush :: forall m. (Monoid m) => (Node -> m) -> Node -> m
crush f = \n -> evalState (visit n) Set.empty
  where
    visit :: (Monoid m) => Node -> State (Set Id) m
    visit node = case node of
      EmptyNode       -> pure mempty
      Rec _           -> pure mempty
      InternedMu mu   -> mappend (f node) <$> visit (internedMuBody mu)
      InternedNode nd -> do
        seen <- get
        let nId = nodeIdentity node
        if nId `Set.member` seen
          then pure mempty
          else do
            modify' (Set.insert nId)
            let visitEdge (Edge _ kids) = mconcat <$> mapM visit kids
            perEdge <- mapM visitEdge (internedNodeEdges nd)
            pure (mappend (f node) (mconcat perEdge))

-- | Restrict a per-node function to normal ('Node') nodes; any other kind of
--   node contributes 'mempty'.
onNormalNodes :: (Monoid m) => (Node -> m) -> (Node -> m)
onNormalNodes f = \case
  n@(Node _) -> f n
  _          -> mempty

-----------------------
------ Unfolding
-----------------------

-- | Unfold a 'Mu' node once by applying its body function to the node itself.
--   Calling this on anything other than a 'Mu' is an error.
unfoldOuterRec :: Node -> Node
unfoldOuterRec n = case n of
  Mu body -> body n
  _       -> error "unfoldOuterRec: Must be called on a Mu node"

-- | Edges of a node. A 'Mu' is unfolded once first; every other non-normal
--   node has no edges.
nodeEdges :: Node -> [Edge]
nodeEdges n = case n of
  Node es -> es
  Mu _    -> nodeEdges (unfoldOuterRec n)
  _       -> []

-- | Re-introduce 'Mu' nodes. Every subterm equal to the one-step unfolding of a
--   'Mu' found in the node is replaced by that 'Mu', iterated to a fixpoint.
--   Nodes containing no 'Mu' are returned unchanged. Memoized.
refold :: Node -> Node
refold = memo (NameTag "refold") refoldOnce
  where
    refoldOnce :: Node -> Node
    refoldOnce n
      | HashMap.null unfoldedToMu = n
      | otherwise                 = fixUnbounded (mapNodes collapse) n
      where
        -- Map from each Mu's one-step unfolding back to the Mu itself.
        unfoldedToMu = crush recordMu n

        recordMu x@(Mu _) = HashMap.singleton (unfoldOuterRec x) x
        recordMu _        = HashMap.empty

        collapse x = maybe x id (HashMap.lookup x unfoldedToMu)

-- | Perform @k@ rounds of unfolding on the 'Mu' nodes reached by 'mapNodes', then
--   replace every 'Mu' that remains with 'EmptyNode'.
--
--   A non-positive bound performs no unfolding. Previously only the literal @0@
--   terminated the recursion, so a negative bound looped forever.
unfoldBounded :: Int -> Node -> Node
unfoldBounded k
  | k <= 0    = mapNodes pruneMu
  | otherwise = unfoldBounded (k - 1) . mapNodes unrollMu
  where
    -- Cut off any remaining recursion.
    pruneMu (Mu _) = EmptyNode
    pruneMu n      = n

    -- Unfold one level of recursion.
    unrollMu n@(Mu _) = unfoldOuterRec n
    unrollMu n        = n


------------
------ Size operations
------------

-- | Number of distinct normal ('Node') nodes reachable from the root.
nodeCount :: Node -> Int
nodeCount n = getSum (crush (onNormalNodes (\_ -> Sum 1)) n)

-- | Total number of edges over all distinct normal nodes reachable from the root.
edgeCount :: Node -> Int
edgeCount n = getSum (crush (onNormalNodes (Sum . length . nodeEdges)) n)

-- | Largest number of edges on any single normal node reachable from the root.
--
--   The 'Max' monoid's identity for 'Int' is 'minBound', so without clamping a
--   node with no normal nodes (e.g. 'EmptyNode') would report @minBound@.
--   Edge counts are never negative, so clamping at 0 changes only that
--   degenerate case.
maxIndegree :: Node -> Int
maxIndegree = max 0 . getMax . crush (onNormalNodes $ Max . length . nodeEdges)

------------
------ Membership
------------

-- | Whether the node represents the term: some edge must represent it.
--   A 'Mu' is unfolded once; empty and other nodes represent nothing.
nodeRepresents :: Node -> Term -> Bool
nodeRepresents node t = case node of
  EmptyNode -> False
  Node es   -> any (`edgeRepresents` t) es
  Mu _      -> nodeRepresents (unfoldOuterRec node) t
  _         -> False

-- | Whether an edge represents a term. All of the following must hold:
--
--   * the symbols match;
--   * the term has exactly one argument per child node of the edge;
--   * each child node represents the corresponding argument;
--   * every equality class of the edge's constraints is satisfied by the term.
--
--   The arity check matters because 'zipWith' truncates silently. Without it, a
--   term with too few or too many arguments could be accepted.
edgeRepresents :: Edge -> Term -> Bool
edgeRepresents e = \t@(Term s ts) -> s == edgeSymbol e
                                  && length children == length ts
                                  && and (zipWith nodeRepresents children ts)
                                  && all (eclassSatisfied t) (unsafeGetEclasses $ edgeEcs e)
  where
    children :: [Node]
    children = edgeChildren e

    -- An equality class holds when every path in it selects the same
    -- (possibly empty) subterm of the term.
    eclassSatisfied :: Term -> PathEClass -> Bool
    eclassSatisfied t pec = allTheSame $ map (\p -> getPath p t) $ unPathEClass pec

    allTheSame :: (Eq a) => [a] -> Bool
    allTheSame =
        \case
          []   -> True
          x:xs -> go x xs
      where
        go !_ []      = True
        go !x (!y:ys) = (x == y) && (go x ys)
    {-# INLINE allTheSame #-}

------------
------ Intersect
------------

-- | Legacy intersection: intersect with 'doIntersect', drop redundant edges, then
--   'refold'. Memoized on both arguments.
--
--   The NOINLINE pragma must name this binding. It is a top-level memoized
--   value, so inlining could duplicate its cache. The pragma previously named
--   'intersect' by mistake.
_oldIntersect :: Node -> Node -> Node
_oldIntersect = memo2 (NameTag "intersect") go
  where
    go :: Node -> Node -> Node
    go n1 n2 = refold (nodeDropRedundantEdges (doIntersect n1 n2))
{-# NOINLINE _oldIntersect #-}


-- | One step of the legacy intersection used by '_oldIntersect'.
--
-- 7/4/21: The unrolling strategy for intersection totally does not generalize beyond
-- recursive nodes which have a self cycle.
--
-- The following will enter an infinite recursion:
--  > t = createGloballyUniqueMu (\n -> Node  [Edge "a" [Node [Edge "a" [n]]]])
--  > intersect t (Node [Edge "a" [t]])
doIntersect :: Node -> Node -> Node
doIntersect n1 n2 = case (n1, n2) of
  (EmptyNode, _) -> EmptyNode
  (_, EmptyNode) -> EmptyNode
  (Mu _, Mu _)   -> n1 -- TODO: Update for multiple Mu's
  (Mu _, _)      -> doIntersect (unfoldOuterRec n1) n2
  (_, Mu _)      -> doIntersect n1 (unfoldOuterRec n2)
  (Node es1, Node es2)
    | n1 == n2  -> n1
    -- NOTE(review): the swapped case delegates to 'intersect' (the newer
    -- implementation), not back to this legacy path.
    | n2 < n1   -> intersect n2 n1
    -- `hash` gives a unique ID of the symbol because they're interned
    | otherwise -> Node (hashJoin (hash . edgeSymbol) intersectEdgeSameSymbol es1 es2)
  _ -> error $ "doIntersect: Unexpected " <> show n1 <> " " <> show n2


-- | Apply 'dropRedundantEdges' to a normal node's edges; other nodes pass through.
nodeDropRedundantEdges :: Node -> Node
nodeDropRedundantEdges n = case n of
  Node es -> Node (dropRedundantEdges es)
  _       -> n

-- | Outcome of checking a candidate edge against the other edges of its cluster.
data RuleOutRes
  = Keep              -- ^ No other edge subsumes the candidate
  | RuledOutBy Edge   -- ^ The candidate is subsumed by this edge

-- | Remove edges subsumed by another edge with the same symbol.
--
-- Edges are grouped by the hash of their symbol and de-duplicated by 'edgeId'.
-- Within a group, for a candidate @c@ and another edge @x@, let @m = c ∩ x@:
--
--   * if @m == x@, then @x@ is subsumed by @c@ and is dropped;
--   * if @m == c@, then @c@ is subsumed by @x@ and is dropped in favour of @x@.
dropRedundantEdges :: [Edge] -> [Edge]
dropRedundantEdges origEs = concatMap reduceCluster clusters
  where
    clusters :: [[Edge]]
    clusters = nubByIdSinglePass edgeId <$> clusterByHash (hash . edgeSymbol) origEs

    reduceCluster :: [Edge] -> [Edge]
    reduceCluster []            = []
    reduceCluster (cand : rest) =
      case ruleOut cand rest of
        -- Optimization: if the winner is larger than the candidate, it is likely
        -- larger than other edges too, so move it to the front to rule out more
        -- edges on the next pass. No noticeable wall-clock difference (7/2/21),
        -- but a few % fewer calls to intersectEdgeSameSymbol.
        (RuledOutBy winner, remaining) -> reduceCluster (winner : remaining)
        (Keep,              remaining) -> cand : reduceCluster remaining

    ruleOut :: Edge -> [Edge] -> (RuleOutRes, [Edge])
    ruleOut _         []             = (Keep, [])
    ruleOut candidate (other:others)
      | meet == other     = ruleOut candidate others
      | meet == candidate = (RuledOutBy other, others)
      | otherwise         = let (verdict, survivors) = ruleOut candidate others
                            in (verdict, other : survivors)
      where
        meet = intersectEdgeSameSymbol candidate other

-- | Intersect two edges, or 'Nothing' when their symbols differ.
intersectEdge :: Edge -> Edge -> Maybe Edge
intersectEdge e1 e2
  | edgeSymbol e1 == edgeSymbol e2 = Just (intersectEdgeSameSymbol e1 e2)
  | otherwise                      = Nothing

-- | Intersect two edges assumed to carry the same symbol. Children are
--   intersected pairwise and the equality constraints combined. Memoized; the
--   arguments are normalised so the smaller edge always comes first.
intersectEdgeSameSymbol :: Edge -> Edge -> Edge
intersectEdgeSameSymbol = memo2 (NameTag "intersectEdgeSameSymbol") meet
  where
    meet :: Edge -> Edge -> Edge
    meet a b
      | b < a = intersectEdgeSameSymbol b a
#ifdef DEFENSIVE_CHECKS
    meet (Edge s kidsA) (Edge _ kidsB)
      | length kidsA /= length kidsB = error $ "Different lengths encountered for children of symbol " <> show s
#endif
    meet a b =
      mkEdge (edgeSymbol a)
             (zipWith intersect (edgeChildren a) (edgeChildren b))
             (edgeEcs a `combineEqConstraints` edgeEcs b)
{-# NOINLINE intersectEdgeSameSymbol #-}

------------
------ New intersection
------------

-- | Intersect two nodes by running 'intersectOpen' with 'emptyIntersectionDom'
--   (no bound free variables, no previously seen intersection problems).
intersect :: Node -> Node -> Node
intersect left right = intersectOpen (emptyIntersectionDom, left, right)

------ Intersection internals

-- | Intersection domain
--
-- Information required to compute the intersection of open terms.
data IntersectionDom = ID
  { idFree   :: Map Id Node
    -- ^ Value of all free variables inside the term (so that we can unfold when necessary)
  , idRecInt :: Set IntersectId
    -- ^ Intersection problems we encountered previously (to avoid infinite unrolling)
  }
  deriving (Show, Eq)

instance Hashable IntersectionDom where
  -- 'Map.toList' and 'Set.toList' both list elements in ascending key order,
  -- which gives a canonical form to hash. Hashing is linear in the size of the
  -- domain; if that becomes a concern, the hash could be cached.
  hashWithSalt salt dom =
    hashWithSalt salt (Map.toList (idFree dom), Set.toList (idRecInt dom))

-- | Domain with no bound free variables and no recorded intersection problems.
emptyIntersectionDom :: IntersectionDom
emptyIntersectionDom = ID { idFree = Map.empty, idRecInt = Set.empty }

intersectOpen :: (IntersectionDom, Node, Node) -> Node
{-# NOINLINE intersectOpen #-}
intersectOpen :: (IntersectionDom, Node, Node) -> Node
intersectOpen = MemoCacheTag
-> ((IntersectionDom, Node, Node) -> Node)
-> (IntersectionDom, Node, Node)
-> Node
forall a b.
(Eq a, Hashable a) =>
MemoCacheTag -> (a -> b) -> a -> b
memo (Text -> MemoCacheTag
NameTag Text
"intersectOpen") (\(IntersectionDom
dom, Node
l, Node
r) -> Node -> Node
refold (Node -> Node) -> Node -> Node
forall a b. (a -> b) -> a -> b
$ Node -> Node
nodeDropRedundantEdges (Node -> Node) -> Node -> Node
forall a b. (a -> b) -> a -> b
$ IntersectionDom -> Node -> Node -> Node
onNode IntersectionDom
dom Node
l Node
r)
  where
    onNode :: IntersectionDom -> Node -> Node -> Node
    onNode :: IntersectionDom -> Node -> Node -> Node
onNode !IntersectionDom
dom Node
l Node
r =
        case (Node
l, Node
r) of
          -- Rule out empty cases first
          -- This justifies the use of nodeIdentity (@i@, @j@) for the other cases
          (Node
EmptyNode, Node
_) -> Node
EmptyNode
          (Node
_, Node
EmptyNode) -> Node
EmptyNode

          -- For closed terms, improve memoization performance by using the empty environment
          (Node, Node)
_ | Set RecNodeId -> Bool
forall a. Set a -> Bool
Set.null (Node -> Set RecNodeId
freeVars Node
l), Set RecNodeId -> Bool
forall a. Set a -> Bool
Set.null (Node -> Set RecNodeId
freeVars Node
r), Bool -> Bool
not (Map Int Node -> Bool
forall k a. Map k a -> Bool
Map.null (IntersectionDom -> Map Int Node
idFree IntersectionDom
dom)) -> Node -> Node -> Node
intersect Node
l Node
r

          -- Special case for self-intersection (equality check is cheap of course: just uses the interned 'Id')
          (Node, Node)
_ | Node
l Node -> Node -> Bool
forall a. Eq a => a -> a -> Bool
== Node
r, Set RecNodeId -> Bool
forall a. Set a -> Bool
Set.null (Node -> Set RecNodeId
freeVars Node
l) -> Node
l

          -- Always intersect nodes in the same order. This is important for two reasons:
          --
          -- 1. It will increase the probability of a cache hit (i.e., improve memoization)
          -- 2. It will increase the probability of being able to use 'ieRecInt'
          (Node, Node)
_ | Node
l Node -> Node -> Bool
forall a. Ord a => a -> a -> Bool
> Node
r -> (IntersectionDom, Node, Node) -> Node
intersectOpen (IntersectionDom
dom, Node
r, Node
l)

          -- If we have seen this exact problem before, refer to enclosing Mu.
          (Node, Node)
_ | IntersectId -> Set IntersectId -> Bool
forall a. Ord a => a -> Set a -> Bool
Set.member (Int -> Int -> IntersectId
IntersectId Int
i Int
j) (IntersectionDom -> Set IntersectId
idRecInt IntersectionDom
dom) -> RecNodeId -> Node
Rec (IntersectId -> RecNodeId
RecIntersect (Int -> Int -> IntersectId
IntersectId Int
i Int
j))

          -- When encountering a 'Mu', extend the domain appropriately.
          (InternedMu InternedMu
l' , InternedMu InternedMu
r') -> Node -> Node
maybeMu (Node -> Node) -> Node -> Node
forall a b. (a -> b) -> a -> b
$ (IntersectionDom, Node, Node) -> Node
intersectOpen ([(Int, Node)] -> IntersectionDom
extendEnv [(Int
i, Node
l), (Int
j, Node
r)] , InternedMu -> Node
internedMuBody InternedMu
l' , InternedMu -> Node
internedMuBody InternedMu
r')
          (InternedMu InternedMu
l' , Node
_            ) -> Node -> Node
maybeMu (Node -> Node) -> Node -> Node
forall a b. (a -> b) -> a -> b
$ (IntersectionDom, Node, Node) -> Node
intersectOpen ([(Int, Node)] -> IntersectionDom
extendEnv [(Int
i, Node
l)        ] , InternedMu -> Node
internedMuBody InternedMu
l' ,                Node
r )
          (Node
_             , InternedMu InternedMu
r') -> Node -> Node
maybeMu (Node -> Node) -> Node -> Node
forall a b. (a -> b) -> a -> b
$ (IntersectionDom, Node, Node) -> Node
intersectOpen ([(Int, Node)] -> IntersectionDom
extendEnv [        (Int
j, Node
r)] ,                Node
l  , InternedMu -> Node
internedMuBody InternedMu
r')

           -- When encountering a free variable, look up the corresponding value in the environment.
           -- (Recall that the case for already-seen intersection problems is are handled above.)
          (Rec RecNodeId
l' , Node
_     ) -> (IntersectionDom, Node, Node) -> Node
intersectOpen (IntersectionDom
dom , RecNodeId -> Node
findFreeVar RecNodeId
l' ,             Node
r )
          (Node
_      , Rec RecNodeId
r') -> (IntersectionDom, Node, Node) -> Node
intersectOpen (IntersectionDom
dom ,             Node
l  , RecNodeId -> Node
findFreeVar RecNodeId
r')

          -- Finally, the real intersection work happens here
          (InternedNode InternedNode
l', InternedNode InternedNode
r') ->
            [Edge] -> Node
Node ([Edge] -> Node) -> [Edge] -> Node
forall a b. (a -> b) -> a -> b
$ (Edge -> Int)
-> (Edge -> Edge -> Edge) -> [Edge] -> [Edge] -> [Edge]
forall a b. (a -> Int) -> (a -> a -> b) -> [a] -> [a] -> [b]
hashJoin (Symbol -> Int
forall a. Hashable a => a -> Int
hash (Symbol -> Int) -> (Edge -> Symbol) -> Edge -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Edge -> Symbol
edgeSymbol)
                            (\Edge
e Edge
e' -> (IntersectionDom, Edge, Edge) -> Edge
intersectOpenEdge (IntersectionDom
dom, Edge
e, Edge
e'))
                            (InternedNode -> [Edge]
internedNodeEdges InternedNode
l')
                            (InternedNode -> [Edge]
internedNodeEdges InternedNode
r')
      where
        -- Node identities (should only be used (forced) if previously established the nodes are not empty)
        i, j :: Id
        i :: Int
i = Node -> Int
nodeIdentity Node
l
        j :: Int
j = Node -> Int
nodeIdentity Node
r

        -- Extend domain when we encounter a 'Mu'
        -- We might see one or two 'Mu's (if we happen to see a 'Mu' on both sides at once)
        extendEnv :: [(Id, Node)] -> IntersectionDom
        extendEnv :: [(Int, Node)] -> IntersectionDom
extendEnv [(Int, Node)]
bindings = ID :: Map Int Node -> Set IntersectId -> IntersectionDom
ID {
              idFree :: Map Int Node
idFree   = Map Int Node -> Map Int Node -> Map Int Node
forall k a. Ord k => Map k a -> Map k a -> Map k a
Map.union ([(Int, Node)] -> Map Int Node
forall k a. Ord k => [(k, a)] -> Map k a
Map.fromList [(Int, Node)]
bindings) (IntersectionDom -> Map Int Node
idFree IntersectionDom
dom)
            , idRecInt :: Set IntersectId
idRecInt = IntersectId -> Set IntersectId -> Set IntersectId
forall a. Ord a => a -> Set a -> Set a
Set.insert (Int -> Int -> IntersectId
IntersectId Int
i Int
j) (IntersectionDom -> Set IntersectId
idRecInt IntersectionDom
dom)
            }

        -- Find value of free variables in the terms
        -- Since we assume the input terms are fully interned, we only deal with 'RecInt'.
        findFreeVar :: RecNodeId -> Node
        findFreeVar :: RecNodeId -> Node
findFreeVar (RecInt Int
intId) | Just Node
n <- Int -> Map Int Node -> Maybe Node
forall k a. Ord k => k -> Map k a -> Maybe a
Map.lookup Int
intId (IntersectionDom -> Map Int Node
idFree IntersectionDom
dom) = Node
n
        findFreeVar RecNodeId
recId = [Char] -> Node
forall a. HasCallStack => [Char] -> a
error ([Char] -> Node) -> [Char] -> Node
forall a b. (a -> b) -> a -> b
$ [Char]
"findFreeVar: unexpected " [Char] -> [Char] -> [Char]
forall a. Semigroup a => a -> a -> a
<> RecNodeId -> [Char]
forall a. Show a => a -> [Char]
show RecNodeId
recId

        -- We only insert a 'Mu' node when necessary.
        maybeMu :: Node -> Node
        maybeMu :: Node -> Node
maybeMu Node
n
          | IntersectId -> RecNodeId
RecIntersect (Int -> Int -> IntersectId
IntersectId Int
i Int
j) RecNodeId -> Set RecNodeId -> Bool
forall a. Ord a => a -> Set a -> Bool
`Set.member` Node -> Set RecNodeId
freeVars Node
n
          = (Node -> Node) -> Node
Mu ((Node -> Node) -> Node) -> (Node -> Node) -> Node
forall a b. (a -> b) -> a -> b
$ \Node
recNode -> RecNodeId -> Node -> Node -> Node
substFree (IntersectId -> RecNodeId
RecIntersect (Int -> Int -> IntersectId
IntersectId Int
i Int
j)) Node
recNode Node
n

          | Bool
otherwise
          = Node
n

-- | Auxiliary to 'intersectOpen': intersect two edges under a shared 'IntersectionDom'.
--
-- The resulting edge keeps the symbol of the /left/ edge, intersects the children
-- pairwise with 'intersectOpen' (truncating to the shorter child list, as 'zip' does),
-- and combines the equality constraints of both edges. Results are memoised under the
-- @"intersectOpenEdge"@ cache tag.
--
-- NOTE(review): the left symbol is used unconditionally; presumably callers only pair
-- edges with matching symbols — confirm at the call site.
intersectOpenEdge :: (IntersectionDom, Edge, Edge) -> Edge
{-# NOINLINE intersectOpenEdge #-}
intersectOpenEdge = memo (NameTag "intersectOpenEdge") intersectEdges
  where
    intersectEdges :: (IntersectionDom, Edge, Edge) -> Edge
    intersectEdges (!dom, l, r) = mkEdge sym kids ecs
      where
        sym  = edgeSymbol l
        kids = [ intersectOpen (dom, a, b) | (a, b) <- zip (edgeChildren l) (edgeChildren r) ]
        ecs  = edgeEcs l `combineEqConstraints` edgeEcs r

------------
------ Union
------------

-- | Union of a list of nodes.
--
-- 'EmptyNode' inputs are discarded first; if nothing remains the result is
-- 'EmptyNode'. Otherwise the edges of the remaining nodes (obtained with
-- 'nodeEdges', in input order) are collected into a single 'Node'.
union :: [Node] -> Node
union ns
  | null nonEmpty = EmptyNode
  | otherwise     = Node (concatMap nodeEdges nonEmpty)
  where
    nonEmpty = filter (/= EmptyNode) ns

----------------------
------ Path operations
----------------------

-- | Restrict a node to the edges along which the given path can be followed.
--
-- For a non-empty path whose head index is @idx@, edges with @idx@ or fewer
-- children are dropped, and on every surviving edge the path is required of
-- its children via 'requirePathList'. An outermost 'Mu' is unfolded once
-- before descending. The empty path returns the node unchanged; a 'Rec' with a
-- non-empty path is an error.
requirePath :: Path -> Node -> Node
requirePath EmptyPath              n         = n
requirePath _                      EmptyNode = EmptyNode
requirePath path                   n@(Mu _)  = requirePath path (unfoldOuterRec n)
requirePath path@(ConsPath idx _)  (Node es) =
    Node [ setChildren e (requirePathList path (edgeChildren e))
         | e <- es
         , length (edgeChildren e) > idx
         ]
requirePath _                      (Rec _)   = error "requirePath: unexpected Rec"

-- | Require a path of a list of sibling nodes.
--
-- The empty path leaves the list unchanged. For @ConsPath idx rest@ only the
-- element at index @idx@ is rewritten with @'requirePath' rest@; if there is no
-- such element, 'ix' targets nothing and the list is returned unchanged.
requirePathList :: Path -> [Node] -> [Node]
requirePathList path ns = case path of
  EmptyPath         -> ns
  ConsPath idx rest -> (ix idx %~ requirePath rest) ns

-- | Paths into a single node.
--
-- Every method unfolds an outermost 'Mu' once before descending, and fails
-- with 'error' on any shape it cannot descend into (e.g. a 'Rec' reached with a
-- non-empty path).
instance Pathable Node Node where
  type Emptyable Node = Node

  -- 'EmptyNode' absorbs every path. Otherwise, take the union of the rest of
  -- the path looked up in the child at the head index of every edge that has one.
  getPath _               EmptyNode = EmptyNode
  getPath EmptyPath       n         = n
  getPath p               n@(Mu _)  = getPath p (unfoldOuterRec n)
  getPath (ConsPath p ps) (Node es) =
      union [ getPath ps child | e <- es, Just child <- [childAt e] ]
    where
      childAt :: Edge -> Maybe Node
      childAt (Edge _ kids) = kids ^? ix p
  getPath p               n         = error $ "getPath: unexpected path " <> show p <> " for node " <> show n

  -- Like 'getPath', but returns every node reached instead of their union.
  getAllAtPath _               EmptyNode = []
  getAllAtPath EmptyPath       n         = [n]
  getAllAtPath p               n@(Mu _)  = getAllAtPath p (unfoldOuterRec n)
  getAllAtPath (ConsPath p ps) (Node es) =
      [ found | e <- es, Just child <- [childAt e], found <- getAllAtPath ps child ]
    where
      childAt :: Edge -> Maybe Node
      childAt (Edge _ kids) = kids ^? ix p
  getAllAtPath p               n         = error $ "getAllAtPath: unexpected path " <> show p <> " for node " <> show n

  -- Apply @f@ at the end of the path along every edge. Note that the empty path
  -- applies @f@ before the 'EmptyNode' check, so @f@ also sees 'EmptyNode'.
  modifyAtPath f EmptyPath       n         = f n
  modifyAtPath _ _               EmptyNode = EmptyNode
  modifyAtPath f p               n@(Mu _)  = modifyAtPath f p (unfoldOuterRec n)
  modifyAtPath f (ConsPath p ps) (Node es) = Node (map rewrite es)
    where
      rewrite :: Edge -> Edge
      rewrite e = setChildren e ((ix p %~ modifyAtPath f ps) (edgeChildren e))
  modifyAtPath _ p               n         = error $ "modifyAtPath: unexpected path " <> show p <> " for node " <> show n

-- | Paths into a list of sibling nodes.
--
-- The head of a non-empty path selects a sibling by index and the remainder is
-- resolved in that node. An out-of-range index yields 'EmptyNode', @[]@, or the
-- unchanged list, respectively.
instance Pathable [Node] Node where
  type Emptyable Node = Node

  -- The empty path denotes the union of all siblings.
  getPath EmptyPath       ns = union ns
  getPath (ConsPath p ps) ns = maybe EmptyNode (getPath ps) (ns ^? ix p)

  -- NOTE(review): the empty path yields [] here, whereas 'getPath' returns the
  -- union of the siblings; confirm this asymmetry is intended.
  getAllAtPath EmptyPath       _  = []
  getAllAtPath (ConsPath p ps) ns = maybe [] (getAllAtPath ps) (ns ^? ix p)

  modifyAtPath _ EmptyPath       ns = ns
  modifyAtPath f (ConsPath p ps) ns = (ix p %~ modifyAtPath f ps) ns



------------------------------------
------ Reduction
------------------------------------

-- | Run 'dropRedundantEdges' on the edge list of every 'Node' visited by
-- 'mapNodes'; any other node shape is returned unchanged.
withoutRedundantEdges :: Node -> Node
withoutRedundantEdges = mapNodes pruneEdges
  where
    pruneEdges :: Node -> Node
    pruneEdges node = case node of
      Node es -> Node (dropRedundantEdges es)
      _       -> node

---------------
--- Reducing Equality Constraints
---------------

-- | Partially reduce a node's equality constraints, starting at the root with
-- no constraints inherited from ancestors.
reducePartially :: Node -> Node
reducePartially = reducePartially' noInheritedConstraints
  where
    noInheritedConstraints = EmptyConstraints

-- | Memoised worker for 'reducePartially'. The first argument holds the
-- equality constraints inherited from the ancestors of the node.
reducePartially' :: EqConstraints -> Node -> Node
reducePartially' = memo2 (NameTag "reducePartially'") go
  where
    go :: EqConstraints -> Node -> Node
    go _         EmptyNode  = EmptyNode
    go _         (Mu body)  = Mu body
    go inherited n@(Node _) =
        -- Per edge: first reduce its children against its own constraints
        -- ('reduceEdgeIntersection'), then recurse into the children.
        modifyNode n (map (reduceChildren inherited . reduceEdgeIntersection inherited))
    go _         (Rec _)    = error "reducePartially: unexpected Rec"

    -- Recurse into an edge's children, passing down the inherited constraints
    -- combined with the edge's own constraints.
    reduceChildren :: EqConstraints -> Edge -> Edge
    reduceChildren inherited e = setChildren e (reduceWithInheritedEcs combined (edgeChildren e))
      where
        combined = inherited `combineEqConstraints` edgeEcs e

    -- | Reduce children with inherited constraints
    --
    -- This function is used to avoid infinite unfolding of recursive nodes,
    -- and we do this by passing constraints from the current edge and ancestors to descendants.
    -- For example, let `tau` be "any" node, and we define
    --
    -- > let n1 = Node [ mkEdge "Pair" [tau, tau] (mkEqConstraints [[path [0, 0], path [0, 1], path [1]]])]
    -- > let n2 = Node [ Edge "Pair" [tau, tau] ]
    -- > let n  = Node [ mkEdge "Pair" [n1, n2]   (mkEqConstraints [[path [0, 0], path [0, 1], path [1]]])]
    --
    -- We notice that, if we call `reducePartially n` without propagating constraints down to its children `n1` or `n2`,
    -- the `tau` can be infinitely expanded between rounds of reduction.
    --
    -- To break such cycles, we actively pass constraints down to children.
    -- In this example, we first call `reducePartially' EmptyConstraints n` at the top level, where the inherited constraint is empty,
    -- so we only need to consider the constraints from the current edge.
    -- Then, we pass the constraints `0.0=0.1=1` down to its children, and `n1` receives `0=1` and `n2` receives nothing.
    -- Next, we reduce the children of `n` by calling `reducePartially' (mkEqConstraints [[path [0], path [1]]]) n1`.
    -- At this node, we will have to combine the inherited constraints `0=1` and the local constraints `0.0=0.1=1`.
    -- Now, we can see that these two constraints contain a contradiction that requires `0=0.0=0.1`, so we can drop the edge.
    --
    -- TODO: this approach does not solve all cases of cycles. See the test case `loop2` in `src/Application/TermSearch/Utils.hs`.
    reduceWithInheritedEcs :: EqConstraints -> [Node] -> [Node]
    reduceWithInheritedEcs ecs children = case ecs of
      -- Contradictory constraints: no child can be satisfied.
      EqContradiction -> EmptyNode <$ children
      -- Child i receives the constraints descended through index i.
      _               -> [ reducePartially' (eqConstraintsDescend ecs i) child
                         | (i, child) <- zip [0 ..] children
                         ]

{-# NOINLINE reducePartially' #-}

-- | Rebuild an edge with its children reduced by 'reduceEqConstraints' against
-- the edge's own constraints (with the given inherited constraints used for the
-- contradiction check). The symbol and the edge's own constraints are kept.
-- Memoised.
reduceEdgeIntersection :: EqConstraints -> Edge -> Edge
reduceEdgeIntersection = memo2 (NameTag "reduceEdgeIntersection") go
  where
    go :: EqConstraints -> Edge -> Edge
    go inherited e = mkEdge (edgeSymbol e) reducedKids own
      where
        own         = edgeEcs e
        reducedKids = reduceEqConstraints own inherited (edgeChildren e)
{-# NOINLINE reduceEdgeIntersection #-}

-- | Reduce a list of sibling nodes against equality constraints.
--
-- @reduceEqConstraints ecs inheritedEcs ns@:
--
--   * If @ecs@ combined with @inheritedEcs@ is contradictory, every node becomes
--     'EmptyNode'.
--
--   * Otherwise, every path mentioned in @ecs@ is first required to exist
--     ('requirePathList'). Then, for each equivalence class of @ecs@, the node at
--     each path of the class is intersected with the intersection of the nodes at
--     the class's other paths. Finally, if any node is 'EmptyNode', all become
--     'EmptyNode'.
--
-- @inheritedEcs@ is only used for the contradiction check.
reduceEqConstraints :: EqConstraints -> EqConstraints -> [Node] -> [Node]
reduceEqConstraints = go
  where
    -- One empty child makes the whole sibling list unsatisfiable.
    propagateEmptyNodes :: [Node] -> [Node]
    propagateEmptyNodes ns = if EmptyNode `elem` ns then map (const EmptyNode) ns else ns

    go :: EqConstraints -> EqConstraints -> [Node] -> [Node]
    go ecs inheritedEcs origNs
      | constraintsAreContradictory (ecs `combineEqConstraints` inheritedEcs) = map (const EmptyNode) origNs
      | otherwise                                                             = propagateEmptyNodes $ foldr reduceEClass withNeededChildren eclasses
      where
        eclasses = unsafeSubsumptionOrderedEclasses ecs

        -- | TODO: Replace with a "requirePathTrie"
        withNeededChildren = foldr requirePathList origNs (concatMap unPathEClass eclasses)

        -- Intersect a non-empty list of nodes; @[a, b, c]@ becomes
        -- @b `intersect` (c `intersect` a)@ (same fold order as before).
        -- Only reached through 'toIntersect' on classes with at least two paths,
        -- so every list passed here has at least one element.
        intersectList :: [Node] -> Node
        intersectList (n : rest) = foldr intersect n rest
        intersectList []         = error "reduceEqConstraints: intersectList called on an empty list"

        -- An equivalence class with fewer than two paths relates nothing to
        -- anything else, so it leaves the nodes unchanged. (Previously a
        -- single-path class made 'dropOnes' produce [[]], crashing 'intersectList'.)
        reduceEClass :: PathEClass -> [Node] -> [Node]
        reduceEClass pec ns
          | length ps < 2 = ns
          | otherwise     = foldr (\(p, nsRestIntersected) ns' -> modifyAtPath (intersect nsRestIntersected) p ns')
                                  ns
                                  (zip ps (toIntersect ns ps))
          where
            ps = unPathEClass pec

        -- For each path, the intersection of the nodes at all the other paths.
        toIntersect :: [Node] -> [Path] -> [Node]
        --toIntersect ns ps = replicate (length ps) $ intersectList $ map (nodeDropRedundantEdges . flip getPath ns) ps
        --toIntersect ns ps = map intersectList $ dropOnes $ map (nodeDropRedundantEdges . flip getPath ns) ps
        --toIntersect ns ps = replicate (length ps) $ intersectList $ map (flip getPath ns) ps
        toIntersect ns ps = map intersectList $ dropOnes $ map (`getPath` ns) ps

        -- | dropOnes [1,2,3,4] = [[2,3,4], [1,3,4], [1,2,4], [1,2,3]]
        dropOnes :: [a] -> [[a]]
        dropOnes xs = zipWith (++) (inits xs) (drop 1 (tails xs))

---------------
--- Debugging
---------------

-- | Debugging helper: look up a normal subnode of @n@ whose 'nodeIdentity' is
-- the given id. When several match, the one kept is the first according to
-- how 'crush' combines results under the 'First' monoid.
getSubnodeById :: Node -> Id -> Maybe Node
getSubnodeById n i = getFirst (crush (onNormalNodes matching) n)
  where
    matching :: Node -> First Node
    matching x
      | nodeIdentity x == i = First (Just x)
      | otherwise           = First Nothing