{-# LANGUAGE MultiWayIf        #-}
{-# LANGUAGE OverloadedStrings #-}

module Data.ECTA.Internal.ECTA.Type (
    RecNodeId(..)

  , Edge(.., Edge)
  , UninternedEdge(..)
  , mkEdge
  , emptyEdge
  , edgeChildren
  , edgeEcs
  , edgeSymbol
  , setChildren

  , Node(.., Node, Mu)
  , InternedNode(..)
  , InternedMu(..)
  , UninternedNode(..)
  , IntersectId -- opaque
  , pattern IntersectId
  , nodeIdentity
  , numNestedMu
  , substFree
  , freeVars
  , modifyNode
  , createMu
  ) where

import Data.Function ( on )
import Data.Hashable ( Hashable(..) )
import Data.List ( sort )
import Data.Maybe ( fromMaybe )
import           Data.Map.Strict ( Map )
import qualified Data.Map.Strict as Map
import           Data.Set ( Set )
import qualified Data.Set as Set

import GHC.Generics ( Generic )

import System.IO.Unsafe ( unsafePerformIO )

import Data.List.Extra ( nubSort )

--   To use ekmett's original `intern` library instead of our single-threaded hashtable-based reimplementation,
--   comment out the import below and uncomment the alternative imports further down.
import Data.Interned.Extended.HashTableBased

-- NOTE 2/7/2022: This version is likely to break because there are nested calls to intern
--                for Mu nodes. See related comment in HashTableBased.hs
--import Data.Interned ( Interned(..), unintern, Id, Cache, mkCache )
--import Data.Interned.Extended.SingleThreaded ( intern )

import Data.ECTA.Internal.Paths
import Data.ECTA.Internal.Term


import Data.Memoization

---------------------------------------------------------------------------------------------

-----------------------------------------------------------------
-------------------------- Mu node table ------------------------
-----------------------------------------------------------------

-- | Identifier carried by a 'Rec' node, referring to the 'Mu' node it recurses on
data RecNodeId =
    -- | Reference to the 'Id' of an interned 'Mu' node
    RecInt !Id

    -- | Reference to an as-yet uninterned 'Mu' node, for which the 'Id' is not yet known
    --
    -- The 'Int' argument is used to distinguish between multiple nested 'Mu' nodes.
    --
    -- NOTE: This is intentionally not an 'Id': it does not refer to the 'Id' of any interned node.
  | RecUnint Int

    -- | Placeholder variable that we use /only/ for depth calculations
    --
    -- The invariant that this is used /only/ for depth calculations, along with the observation that depth calculation
    -- does not depend on the exact choice of variable, justifies substituting any other variable for 'RecDepth' in
    -- terms containing 'RecDepth' in all contexts.
  | RecDepth

    -- | Refer to Mu-node-to-be-constructed during intersection
    --
    -- TODO: It is obviously not very elegant to have a constructor here specifically for one algorithm. Ideally, we
    -- would parameterize 'Node' with the type of the identifiers in it. This might be useful also to rule out many
    -- other cases (specifically, most of the time we are dealing with fully interned nodes, and so the only
    -- constructor we expect is 'RecInt').
  | RecIntersect IntersectId
  deriving ( Eq, Ord, Show, Generic )

-- | Context-free references to a 'Mu' node introduced by 'intersect'
--
-- Background: This is a generalization of the idea of being able to refer to the "immediately enclosing binder", and
-- then only dealing with graphs with the property that we never need to refer past that enclosing binder. This too
-- would allow us to refer to a 'Mu' node without knowing its 'Id', at the cost of requiring a substitution when we
-- discover that 'Id' to turn this into a 'RecInt'. The generalization is that all we need is /some/ way to refer to
-- that 'Mu' node concretely, without its 'Id', and we have one: intersection introduces a 'Mu' whenever it
-- encounters a 'Mu' on the left or the right,
-- /and will then not introduce another 'Mu' for that same intersection problem/
-- (at least, not in the same scope). This means that the 'Id's of the left and right operands uniquely identify the
-- 'Mu' node to be constructed by 'intersect'.
--
-- Furthermore, since we cache the free variables in a term, we have a cheap check to see if we need the 'Mu' node at
-- all. This means that /if/ the input graphs satisfy the property that there are no references past 'Mu' nodes, the
-- output should too: we will not introduce redundant 'Mu' nodes.
--
-- NOTE: Although intersect has three cases in which it introduces 'Mu' nodes ('Mu' in both operands, 'Mu' in the
-- left, or 'Mu' in the right), we don't need that distinction here: we just need to know the 'Id' of the two
-- operands, so that if we see a call to intersect again /with those same two operands/ (no matter what kind of nodes
-- they are), we can refer to the newly constructed 'Mu' node.
data IntersectId =
     -- Invariant: the two 'Id's should be ordered (guaranteed by the pattern synonym constructor)
     UnsafeIntersectId !Id !Id
  deriving ( Eq, Ord, Show, Generic )

-- | Smart constructor and matcher for 'IntersectId'
--
-- Construction stores the smaller 'Id' first, so the order of the arguments does not matter:
-- @IntersectId a b == IntersectId b a@. Matching exposes the two stored 'Id's in that normalized order.
pattern IntersectId :: Id -> Id -> IntersectId
pattern IntersectId i j <- UnsafeIntersectId i j
  where
    IntersectId i j = UnsafeIntersectId (min i j) (max i j)

-- Both instances rely on the 'Generic'-based default methods of 'Hashable' (each type derives 'Generic' above).
instance Hashable RecNodeId
instance Hashable IntersectId

-----------------------------------------------------------------
----------------------------- Edges -----------------------------
-----------------------------------------------------------------

-- | An interned edge: the 'Id' it was assigned together with the uninterned contents it was built from
--
-- The 'Eq', 'Ord' and 'Hashable' instances for 'Edge' look only at 'edgeId'.
data Edge = InternedEdge { edgeId         :: !Id
                         , uninternedEdge :: !UninternedEdge
                         }

-- | Render an edge in the syntax of the 'Edge' pattern when it has no equality constraints, and as a call to
-- 'mkEdge' (symbol, children, constraints) otherwise.
instance Show Edge where
  show e
      | ecs == EmptyConstraints = wrap ["Edge",   show sym, show kids]
      | otherwise               = wrap ["mkEdge", show sym, show kids, show ecs]
    where
      sym  = edgeSymbol e
      kids = edgeChildren e
      ecs  = edgeEcs e

      -- Space-separate the words and surround them with parentheses
      wrap :: [String] -> String
      wrap ws = "(" ++ unwords ws ++ ")"

--instance Show Edge where
--  show e = "InternedEdge " ++ show (edgeId e) ++ " " ++ show (edgeSymbol e) ++ " " ++ show (edgeChildren e) ++ " " ++ show (edgeEcs e)

-- | The symbol labelling the edge
edgeSymbol :: Edge -> Symbol
edgeSymbol = uEdgeSymbol . uninternedEdge

-- | The child nodes of the edge
edgeChildren :: Edge -> [Node]
edgeChildren = uEdgeChildren . uninternedEdge

-- | The equality constraints attached to the edge
edgeEcs :: Edge -> EqConstraints
edgeEcs = uEdgeEcs . uninternedEdge

-- | Two interned edges are equal iff their 'Id's are equal
instance Eq Edge where
  (==) = (==) `on` edgeId

-- | Edges are ordered by their 'Id' (consistent with the 'Eq' instance)
instance Ord Edge where
  compare = compare `on` edgeId

-- | Edges hash by their 'Id' only (consistent with the 'Eq' instance)
instance Hashable Edge where
  hashWithSalt s = hashWithSalt s . edgeId


-----------------------------------------------------------------
------------------------------ Nodes ----------------------------
-----------------------------------------------------------------

-- | Payload of an interned recursive ('Mu') node
data InternedMu = MkInternedMu {
      -- | 'Id' of the node itself
      internedMuId :: {-# UNPACK #-} !Id

      -- | The body of the 'Mu'
      --
      -- Recursive occurrences of this node should be
      --
      -- > Rec (RecInt internedMuId)
    , internedMuBody :: !Node

      -- | The body of the 'Mu', before it was assigned an 'Id'
      --
      -- Invariant:
      --
      -- >    substFree internedMuId (Rec (RecUnint (numNestedMu internedMuBody))) internedMuBody
      -- > == internedMuShape
    , internedMuShape :: !Node
    }
  deriving ( Show )

-- | Payload of an interned node with outgoing edges, together with cached summary information
data InternedNode = MkInternedNode {
      -- | The 'Id' of the node itself
      internedNodeId :: {-# UNPACK #-} !Id

      -- | All outgoing edges
    , internedNodeEdges :: ![Edge]

      -- | Maximum Mu nesting depth in the term
    , internedNodeNumNestedMu :: !Int

      -- | Free variables in the term
    , internedNodeFree :: !(Set RecNodeId)
    }
  deriving ( Show )

-- | A node in the graph
--
-- 'Eq', 'Ord', 'Show' and 'Hashable' are provided by hand-written instances rather than derived.
data Node =
    -- | Interned node with outgoing edges
    InternedNode {-# UNPACK #-} !InternedNode

    -- | The empty node
  | EmptyNode

    -- | Interned recursive node; occurrences of it inside its body are 'Rec' nodes
  | InternedMu {-# UNPACK #-} !InternedMu

    -- | Reference to a 'Mu' node
  | Rec !RecNodeId

-- | Nodes of the same kind are compared by 'Id' ('InternedNode', 'InternedMu') or by 'RecNodeId' ('Rec');
-- nodes of different kinds are never equal.
instance Eq Node where
  InternedNode l == InternedNode r = internedNodeId l == internedNodeId r
  InternedMu   l == InternedMu   r = internedMuId   l == internedMuId   r
  Rec          l == Rec          r =                l ==                r
  EmptyNode      == EmptyNode      = True
  _              == _              = False

-- | Render a node: edges for 'InternedNode', 'Id' and body for 'InternedMu', and the referenced id for 'Rec'
instance Show Node where
  show (InternedNode node) = "(Node " <> show (internedNodeEdges node) <> ")"
  show EmptyNode           = "EmptyNode"
  show (InternedMu mu)     = "(Mu " <> show (internedMuId mu) <> " " <> show (internedMuBody mu) <> ")"
  show (Rec n)             = "(Rec " <> show n <> ")"

-- | Order nodes by an 'Int' descriptor
--
-- The descriptor is @-1@ for 'EmptyNode', and @3*i@, @3*i+1@ and @3*i+2@ for an 'InternedNode', 'InternedMu'
-- and 'Rec' (respectively) whose 'Id' is @i@.
--
-- NOTE: only 'Rec' nodes holding a 'RecInt' can be compared; any other 'RecNodeId' raises an error. This is
-- stricter than the 'Eq' instance, which accepts every 'RecNodeId'.
instance Ord Node where
  compare = compare `on` nodeDescriptorInt
    where
      nodeDescriptorInt :: Node -> Int
      nodeDescriptorInt EmptyNode           = -1
      nodeDescriptorInt (InternedNode node) = 3 * internedNodeId node
      nodeDescriptorInt (InternedMu mu)     = 3 * internedMuId mu + 1
      nodeDescriptorInt (Rec recId)         = 3 * recIntId recId + 2

      -- Only references to interned 'Mu' nodes carry an 'Id' we can order by
      recIntId :: RecNodeId -> Id
      recIntId (RecInt nid) = nid
      recIntId other        = error $ "compare: unexpected " <> show other


-- | Hash a node by its identity (consistent with the 'Eq' instance)
--
-- 'EmptyNode', 'InternedMu' and 'Rec' mix in a distinct negative tag; 'InternedNode' hashes its 'Id' directly.
instance Hashable Node where
  hashWithSalt s EmptyNode           = s `hashWithSalt` (-1 :: Int)
  hashWithSalt s (InternedMu mu)     = s `hashWithSalt` (-2 :: Int) `hashWithSalt` internedMuId mu
  hashWithSalt s (Rec i)             = s `hashWithSalt` (-3 :: Int) `hashWithSalt` i
  hashWithSalt s (InternedNode node) = s `hashWithSalt` internedNodeId node

-- | Maximum number of nested Mus in the term
--
-- @O(1)@ provided that there are no unbounded Mu chains in the term: interned nodes cache the value, and each
-- directly nested 'InternedMu' adds one.
numNestedMu :: Node -> Int
numNestedMu EmptyNode           = 0
numNestedMu (InternedNode node) = internedNodeNumNestedMu node
numNestedMu (InternedMu   mu)   = 1 + numNestedMu (internedMuBody mu)
numNestedMu (Rec _)             = 0

-- | Free variables in the term
--
-- @O(1)@ in the size of the graph, provided that there are no unbounded Mu chains in the term.
-- @O(log n)@ in the number of free variables in the graph, which we expect to be orders of magnitude smaller than the
-- size of the graph (indeed, we don't expect more than a handful).
--
-- A 'Mu' binds its own 'RecInt' 'Id', so that variable is removed from the free variables of its body.
freeVars :: Node -> Set RecNodeId
freeVars EmptyNode           = Set.empty
freeVars (InternedNode node) = internedNodeFree node
freeVars (InternedMu   mu)   = Set.delete (RecInt (internedMuId mu)) (freeVars (internedMuBody mu))
freeVars (Rec i)             = Set.singleton i

----------------------
------ Getters and setters
----------------------

-- | The 'Id' identifying a node
--
-- Defined for 'InternedMu', 'InternedNode', and 'Rec' nodes that reference an interned 'Mu' ('RecInt');
-- 'EmptyNode' and any other 'Rec' raise an error.
nodeIdentity :: Node -> Id
nodeIdentity (InternedMu   mu)   = internedMuId mu
nodeIdentity (InternedNode node) = internedNodeId node
nodeIdentity (Rec (RecInt i))    = i
nodeIdentity n                   = error $ "nodeIdentity: unexpected node " <> show n

-- | Replace the children of an edge, keeping its symbol and equality constraints
--
-- The new edge is built with 'mkEdge'.
setChildren :: Edge -> [Node] -> Edge
setChildren e ns = mkEdge (edgeSymbol e) ns (edgeEcs e)

-- | Rebuild an edge from its symbol and children only, discarding its equality constraints
--
-- The leading underscore suppresses GHC's unused-binding warning.
_dropEcs :: Edge -> Edge
_dropEcs e = Edge (edgeSymbol e) (edgeChildren e)


-----------------------------------------------------------------
------------------------- Interning Nodes -----------------------
-----------------------------------------------------------------

-- | A node before it has been assigned an 'Id' by interning
data UninternedNode =
      -- | Node with the given outgoing edges
      UninternedNode ![Edge]

      -- | The empty node
    | UninternedEmptyNode

      -- | Recursive node
      --
      -- The function should be parametric in the Id:
      --
      -- > substFree i (Rec j) (f i) == f j
      --
      -- See 'shape' for additional discussion.
    | UninternedMu !(RecNodeId -> Node)

-- | Equality on uninterned nodes
--
-- Edge lists are compared element-wise; two 'UninternedMu' nodes are equal when their 'shape's are equal.
instance Eq UninternedNode where
  UninternedNode es   == UninternedNode es'   = es == es'
  UninternedEmptyNode == UninternedEmptyNode = True
  UninternedMu mu     == UninternedMu mu'     = shape mu == shape mu'
  _                   == _                   = False

-- | Hash an uninterned node as a (constructor tag, payload) pair; 'UninternedMu' is hashed by its 'shape',
-- consistent with the 'Eq' instance.
instance Hashable UninternedNode where
  hashWithSalt salt UninternedEmptyNode = hashWithSalt salt (0 :: Int, ())
  hashWithSalt salt (UninternedNode es) = hashWithSalt salt (1 :: Int, es)
  hashWithSalt salt (UninternedMu mu)   = hashWithSalt salt (2 :: Int, shape mu)

-- | Interning of nodes
--
-- 'identify' turns an 'UninternedNode' into the corresponding 'Node', given the 'Id' assigned by the cache. For
-- edge nodes it also computes the cached Mu depth and free-variable set from the children.
instance Interned Node where
  type Uninterned  Node = UninternedNode
  data Description Node = DNode !UninternedNode
    deriving ( Eq, Generic )

  describe = DNode

  identify i (UninternedNode es) = InternedNode $ MkInternedNode {
        internedNodeId          = i
      , internedNodeEdges       = es
      , internedNodeNumNestedMu = maximum (0 : map numNestedMu children) -- depth is always >= 0
      , internedNodeFree        = Set.unions (map freeVars children)
      }
    where
      -- All children across all outgoing edges
      children = concatMap edgeChildren es
  identify _ UninternedEmptyNode = EmptyNode
  identify i (UninternedMu n)    = InternedMu $ MkInternedMu {
        internedMuId    = i
      , internedMuBody  = n (RecInt i)

        -- In order to establish the invariant for internedMuShape, we need to know
        --
        -- >    substFree internedMuId (Rec (RecUnint (numNestedMu internedMuBody))) internedMuBody
        -- > == internedMuShape
        --
        -- This follows from parametricity:
        --
        -- >    internedMuShape
        -- >      -- { definition of internedMuShape }
        -- > == shape n
        -- >      -- { definition of shape }
        -- > == n (RecUnint (numNestedMu (n RecDepth)))
        -- >      -- { by parametricity, depth is independent of the variable number }
        -- > == n (RecUnint (numNestedMu (n (RecInt i))))
        -- >      -- { parametricity again }
        -- > == substFree i (Rec (RecUnint (numNestedMu (n (RecInt i))))) (n (RecInt i))
        -- >      -- { definition of internedMuId and internedMuBody }
        -- > == substFree internedMuId (Rec (RecUnint (numNestedMu internedMuBody))) internedMuBody
        --
        -- QED.
      , internedMuShape = shape n
      }

  cache = nodeCache

-- Uses the 'Generic'-based default methods of 'Hashable'; the wrapped 'UninternedNode' hashes 'Mu' bodies by their
-- 'shape'.
instance Hashable (Description Node)

-- | The global cache used when interning 'Node's
--
-- 'unsafePerformIO' creates one top-level cache. The NOINLINE pragma prevents GHC from inlining 'nodeCache' at its
-- use sites, which could otherwise create more than one cache.
nodeCache :: Cache Node
nodeCache = unsafePerformIO freshCache
{-# NOINLINE nodeCache #-}

-- | Compute the "shape" of the body of a 'Mu'
--
-- During interning we need to know the shape of the body of a 'Mu' node /before/ we know the 'Id' of that node. We do
-- this by replacing any 'Rec' nodes in the node by placeholders. We have to be careful here however to correctly assign
-- placeholders in the presence of nested 'Mu' nodes. For example, if the user writes a term such as
--
-- > -- f (f (f ... (g (g (g ... a)))))
-- > Mu $ \r -> Node [
-- >     Edge "f" [r]
-- >   , Edge "g" [ Mu $ \r' -> Node [
-- >                    Edge "g" [r']
-- >                  , Edge "a" []
-- >                  ]
-- >              ]
-- >   ]
--
-- we should be careful not to accidentally identify @r@ and @r'@.
--
-- Precondition: the function must be parametric in the choice of variable names:
--
-- > substFree i (Rec j) (f i) == f j
--
-- Put another way, we must rule out /exotic terms/: in our case, exotic terms would be uninterned @Mu@ nodes that
-- have one shape when given one variable, and another shape when given a different variable. We do not have such terms.
-- (Of course, a function such as substitution /does/ do one thing if it sees one variable and another thing when it
-- sees a different variable, but this is okay: substitution is a function /on/ terms, mapping non-exotic terms to
-- non-exotic terms.)
--
-- Implementation note: We are calling the function twice: once to compute the depth of the node, and then a second time
-- to give it the right placeholder variable. Some observations:
--
-- o Semantically, this is okay; if we were working with a first order representation, it would be the equivalent of
--   first executing some kind of function @Node -> Int@, followed by some kind of substitution @Node -> Node@. It's the
--   same with the higher order representation, except that in /principle/ the function could do entirely different
--   things when given 'RecDepth' versus some other kind of placeholder; the parametricity precondition rules this out.
-- o It's slightly inefficient, but since this lives at the user interface boundary only, performance here is not
--   critical: internally we work with interned nodes only, and this function is not relevant.
-- o It /is/ important that the placeholder we pick here is uniquely determined by the node itself: this is what
--   justifies using 'shape' during interning.
shape :: (RecNodeId -> Node) -> Node
shape f = f placeholder
  where
    -- Nesting depth of the body, computed with the depth-only placeholder
    placeholder = RecUnint (numNestedMu (f RecDepth))

-----------------------------------------------------------------
------------------------ Interning Edges ------------------------
-----------------------------------------------------------------

-- | The uninterned representation of an 'Edge': its symbol, its children, and its equality constraints.
data UninternedEdge = UninternedEdge { uEdgeSymbol    :: !Symbol
                                     , uEdgeChildren  :: ![Node]
                                     , uEdgeEcs       :: !EqConstraints
                                     }
  deriving ( Eq, Show, Generic )

-- Uses the 'Generic'-based default implementation.
instance Hashable UninternedEdge

instance Interned Edge where
  type Uninterned  Edge = UninternedEdge
  data Description Edge = DEdge {-# UNPACK #-} !UninternedEdge
    deriving ( Eq, Generic )

  -- The interning key is simply the complete uninterned edge.
  describe = DEdge

  -- Pair the freshly assigned identifier with the uninterned edge.
  identify i e = InternedEdge i e

  -- All edges share the single global table 'edgeCache'.
  cache = edgeCache

instance Hashable (Description Edge)

-- | Global interning table for edges (used as 'cache' by the 'Interned' instance for 'Edge').
--
-- The table is created with 'unsafePerformIO'. The @NOINLINE@ pragma is essential: without it GHC would be free to
-- inline the definition at each use site and allocate several independent tables, breaking the sharing that
-- interning relies on.
edgeCache :: Cache Edge
edgeCache = unsafePerformIO freshCache
{-# NOINLINE edgeCache #-}

-----------------------------------------------------------------
----------------------- Smart constructors ----------------------
-----------------------------------------------------------------

-------------------
------ Edge constructors
-------------------

-- | Bidirectional pattern for edges.
--
-- Matching exposes the symbol and children and ignores the edge's equality constraints. Constructing interns an edge
-- with 'EmptyConstraints'; use 'mkEdge' to build an edge that carries constraints.
pattern Edge :: Symbol -> [Node] -> Edge
pattern Edge s ns <- (InternedEdge _ (UninternedEdge s ns _)) where
  Edge s ns = intern $ UninternedEdge s ns EmptyConstraints

{-# COMPLETE Edge #-}

-- | An edge with an 'EmptyNode' child. 'isEmptyEdge' holds for it, so 'mkNode' drops it.
emptyEdge :: Edge
emptyEdge = Edge "" [EmptyNode]

-- | An edge is considered empty when at least one of its children is 'EmptyNode'.
isEmptyEdge :: Edge -> Bool
isEmptyEdge (Edge _ ns) = EmptyNode `elem` ns

-- | Drop every edge for which 'isEmptyEdge' holds.
removeEmptyEdges :: [Edge] -> [Edge]
removeEmptyEdges = filter (not . isEmptyEdge)

-- | Construct an edge that carries equality constraints.
--
-- Contradictory constraints collapse the edge to 'emptyEdge' (which 'mkNode' subsequently discards); otherwise the
-- edge is interned with the given constraints.
mkEdge :: Symbol -> [Node] -> EqConstraints -> Edge
mkEdge s ns ecs
  | constraintsAreContradictory ecs = emptyEdge
  | otherwise                       = intern $ UninternedEdge s ns ecs


-------------------
------ Node constructors
-------------------

{-# COMPLETE Node, EmptyNode, Mu, Rec #-}

-- | Bidirectional pattern for ordinary (non-'Mu') interned nodes.
--
-- Matching exposes the node's edges; constructing goes through 'mkNode', which normalises the edge list.
pattern Node :: [Edge] -> Node
pattern Node es <- (InternedNode (internedNodeEdges -> es)) where
  Node = mkNode

-- | Smart constructor behind the 'Node' pattern.
--
-- Empty edges are removed first; if none remain the result is 'EmptyNode'. Otherwise the remaining edges are sorted
-- and deduplicated ('nubSort') before interning, so permutations and duplicates of the same edges produce the same
-- 'UninternedNode'.
mkNode :: [Edge] -> Node
mkNode es = case removeEmptyEdges es of
              []  -> EmptyNode
              es' -> intern $ UninternedNode $ nubSort es'

-- | Variant of 'mkNode' that only sorts the edges and does not deduplicate them.
--
-- NOTE(review): judging by the name, callers must guarantee the edge list is already free of duplicates.
_mkNodeAlreadyNubbed :: [Edge] -> Node
_mkNodeAlreadyNubbed es = case removeEmptyEdges es of
                            []  -> EmptyNode
                            es' -> intern $ UninternedNode $ sort es'

-- | An optimized Node constructor that avoids the interning/preprocessing of the Node constructor
--   when nothing changes
--
-- Applies the function to the node's edges. If the result equals the original edge list, the original node is
-- returned as-is; otherwise a new node is built via 'Node' (and hence 'mkNode'). Calls 'error' on anything that the
-- 'Node' pattern does not match ('EmptyNode', 'Mu' nodes and 'Rec').
modifyNode :: Node -> ([Edge] -> [Edge]) -> Node
modifyNode n@(Node es) f
  | es' == es = n
  | otherwise = Node es'
  where
    es' = f es
modifyNode n _ = error $ "modifyNode: unexpected node " <> show n

-- | 'Nothing' if the edge is empty (see 'isEmptyEdge'), otherwise the edge itself.
_collapseEmptyEdge :: Edge -> Maybe Edge
_collapseEmptyEdge e
  | isEmptyEdge e = Nothing
  | otherwise     = Just e

------ Mu

-- | Pattern synonym for 'Mu' (recursive) nodes
--
-- When we go underneath a Mu constructor, we need to bind the corresponding Rec node to something: that's why pattern
-- matching on 'Mu' yields a function. Code that wants to traverse the term as-is should match on the interned
-- constructors instead (and then deal with the dangling references).
--
-- An identity function
--
-- > foo (Mu f) = Mu f
--
-- will run in O(1) time:
--
-- > foo (Mu f) = Mu f
-- >   -- { expand view pattern }
-- > foo node | Just f <- matchMu node = createMu f
-- >   -- { case for @InternedMu mu@ }
-- > foo (InternedMu mu) | Just f <- matchMu (InternedMu mu) = createMu f
-- >   -- { definition of matchMu }
-- > foo (InternedMu mu) = let f = \n' ->
-- >                          if | n' == Rec (RecUnint (numNestedMu (internedMuBody mu))) ->
-- >                                internedMuShape mu
-- >                            | n' == Rec RecDepth ->
-- >                                internedMuShape mu
-- >                            | otherwise ->
-- >                                substFree (RecInt (internedMuId mu)) n' (internedMuBody mu)
-- >                       in createMu f
-- >   -- { definition of createMu }
-- > foo (InternedMu mu) = intern $ UninternedMu (f . Rec)
--
-- At this point, @intern@ will call @shape (f . Rec)@, which will call @f . Rec@ twice: once with 'RecDepth' to compute
-- the depth, and then once again with that depth to substitute a placeholder. Both of these special cases will use
-- 'internedMuShape' (and moreover, the depth calculation on 'internedMuShape' is @O(1)@).
pattern Mu :: (Node -> Node) -> Node
pattern Mu f <- (matchMu -> Just f)
  where
    Mu = createMu

-- | Construct recursive node
--
-- The argument maps the node standing for the recursive occurrence to the body. It is wrapped so that it receives
-- 'Rec' applied to a 'RecNodeId', and the result is interned as an 'UninternedMu'.
--
-- Implementation note: 'createMu' and 'matchMu' interact in non-trivial ways; see docs of the 'Mu' pattern synonym
-- for performance considerations.
createMu :: (Node -> Node) -> Node
createMu f = intern $ UninternedMu (f . Rec)

-- | Match on a 'Mu' node
--
-- Returns 'Nothing' for anything other than an 'InternedMu'. For an 'InternedMu', the returned function substitutes
-- its argument for the bound variable in the body. Two placeholder arguments are short-circuited to
-- 'internedMuShape' instead of performing a substitution.
--
-- Implementation note: 'createMu' and 'matchMu' interact in non-trivial ways; see docs of the 'Mu' pattern synonym
-- for performance considerations.
matchMu :: Node -> Maybe (Node -> Node)
matchMu (InternedMu mu) = Just $ \n' ->
    if | n' == placeholder ->
          -- Special case justified by the invariant on 'internedMuShape'
          internedMuShape mu
       | n' == Rec RecDepth ->
          -- The use of 'RecDepth' implies that we are computing a depth:
          --
          -- >    numNestedMu (substFree (RecInt (internedMuId mu)) (Rec RecDepth) (internedMuBody mu))
          -- >      -- { depth calculation does not depend on choice of variable }
          -- > == numNestedMu (substFree (RecInt (internedMuId mu)) (Rec (RecUnint (numNestedMu (internedMuBody mu)))) (internedMuBody mu))
          -- >      -- { invariant of internedMuShape }
          -- > == numNestedMu (internedMuShape mu)
          internedMuShape mu
       | otherwise ->
          substFree (RecInt (internedMuId mu)) n' (internedMuBody mu)
  where
    -- Bound outside the lambda so 'numNestedMu' of the body is evaluated at most once, however often the returned
    -- function is applied.
    placeholder = Rec (RecUnint (numNestedMu (internedMuBody mu)))
matchMu _otherwise = Nothing

-- | Substitution
--
-- @substFree i n@ will replace all occurrences of @Rec i@ by @n@. We appeal to the uniqueness of node IDs
-- and assume that all occurrences of @i@ must be free (in other words, that any occurrences of 'Mu' will have a
-- /different/ identifier).
--
-- Postcondition:
--
-- > substFree i (Rec i) == id
substFree :: RecNodeId -> Node -> Node -> Node
substFree old new = substFree' (Map.singleton old new)

-- | Generalization of 'substFree' to multiple binders.
--
-- Every @Rec i@ whose @i@ is a key of the environment is replaced by the corresponding node; every other 'Rec' is
-- left as-is.
substFree' :: Map RecNodeId Node -> Node -> Node
substFree' env node = case template node of
                        Template f -> f env

------ Substitution internals

-- | The template of something is that something with holes for as-yet unknown 'Id's
--
-- This datatype should satisfy two properties for 'template' to work correctly:
--
-- 1. Forcing the 'Template' to WHNF should not result in any recursive calls
--    (so that the recursion isn't totally unrolled before memoization can happen).
-- 2. But forcing the /function inside/ the 'Template' to WHNF /should/ result in all recursive calls to happen,
--    (/before/ the function is executed: executing the function should /not/ cause further calls to 'template').
--
-- The idea is that once a function has returned a 'Template', applying that 'Template' should not cause further
-- recursive calls to that function. Any expensive computation done by the function is then not repeated, and is done
-- independently of the environment (the 'Map') that we provide to the 'Template'. Put another way: the function can be
-- memoized independently of that environment. For substitution this may not matter very much, but for other
-- functions it could. Note however that the resulting 'Template' does build the graph on each invocation; this may
-- still be prohibitively expensive. See 'intersect' for an example of how we can avoid an environment altogether.
-- (This is not an option for substitution of course, where the environment is part of the API of the function.)
--
-- This must stay a @data@ type (with a lazy field) rather than a @newtype@: with a @newtype@, forcing the
-- 'Template' would force the function inside it, violating property 1.
data Template a = Template (Map RecNodeId Node -> a)

-- | Commute @[]@ and 'Template'
--
-- Forces all elements in the list once the function inside the resulting 'Template' is forced to WHNF.
sequenceTemplate :: [Template a] -> Template [a]
sequenceTemplate = Template . go []
  where
    -- Walk the whole list first, forcing each inner function (bang pattern) and accumulating them in reverse;
    -- only then return a function that applies them all to the environment, in the original order.
    go :: [Map RecNodeId Node -> a] -> [Template a] -> Map RecNodeId Node -> [a]
    go acc []               = \env -> reverse (map ($ env) acc)
    go acc (Template !f:fs) = go (f:acc) fs

-- | Extract the shape from a term
--
-- Somewhat serendipitously (or does this point to some deeper truth?) this also serves as a definition of substitution:
-- any free variables in the original node will become " holes " in the 'Template'.
--
-- We do not use the pattern synonyms here, because 'template' is used (through 'substFree') to /define/ those
-- pattern synonyms.
template :: Node -> Template Node
{-# NOINLINE template #-}
template = memo (NameTag "template") onNode
  where
    onNode :: Node -> Template Node
    onNode n = Template $
        -- The case is under the 'Template' constructor, so forcing the result to WHNF does not recurse.
        case n of
          EmptyNode         -> \_ -> EmptyNode
          InternedNode node -> case sequenceTemplate $ map templateEdge (internedNodeEdges node) of
                                 Template !f -> \env -> mkNode (f env)
          InternedMu mu     -> case onNode (internedMuBody mu) of
                                 -- Under the binder, the Mu's own identifier is mapped to the recursive node @r@.
                                 Template !f -> \env -> createMu $ \r ->
                                                  f (Map.insert (RecInt (internedMuId mu)) r env)
          -- Free variables are holes: fill them from the environment, or leave them untouched.
          Rec i             -> \env -> fromMaybe n (Map.lookup i env)

-- | Internal auxiliary to 'template'
--
-- Builds a 'Template' for an edge: its children are templated via 'template', and the edge itself is rebuilt with
-- 'setChildren' once the environment is supplied.
templateEdge :: Edge -> Template Edge
{-# NOINLINE templateEdge #-}
templateEdge = memo (NameTag "templateEdge") onEdge
  where
    onEdge :: Edge -> Template Edge
    onEdge e =
        Template $ case sequenceTemplate (map template (edgeChildren e)) of
                     Template !f -> setChildren e . f