{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
----------------------------------------------------------------------
-- |
-- Module      :  Data.DecisionDiagram.BDD.Internal.Node
-- Copyright   :  (c) Masahiro Sakai 2021
-- License     :  BSD-style
--
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  unstable
-- Portability :  non-portable
--
----------------------------------------------------------------------
module Data.DecisionDiagram.BDD.Internal.Node
  (
  -- * Low level node type
    Node (Leaf, Branch)
  , nodeId

  , numNodes

  -- * Fold
  , fold
  , fold'
  , mkFold'Op

  -- * (Co)algebraic structure
  , Sig (..)

  -- * Graph
  , Graph
  , toGraph
  , toGraph'
  , foldGraph
  , foldGraphNodes
  ) where

import Control.Monad
import Control.Monad.ST
import Control.Monad.ST.Unsafe
import Data.Functor.Identity
import Data.Hashable
import qualified Data.HashTable.Class as H
import qualified Data.HashTable.ST.Cuckoo as C
import Data.Interned
import Data.IntMap.Lazy (IntMap)
import qualified Data.IntMap.Lazy as IntMap
import Data.STRef
import GHC.Generics
import GHC.Stack

-- ------------------------------------------------------------------------

-- | Hash-consed node of a BDD or ZDD.
--
-- The 'Id' is the identifier assigned when the node is interned; it is the
-- only component consulted by the 'Eq' and 'Hashable' instances below. The
-- 'UNode' field holds the node's uninterned structure.
data Node = Node {-# UNPACK #-} !Id UNode
  deriving (Show)

-- | Two nodes are equal exactly when their interned ids are equal.
instance Eq Node where
  x == y = case (x, y) of
    (Node a _, Node b _) -> a == b

-- | Hash a node by its interned id only, consistent with the 'Eq' instance.
instance Hashable Node where
  hashWithSalt salt node = case node of
    Node ident _ -> hashWithSalt salt ident

-- | The @True@ leaf. Matching inspects the uninterned form for 'UT';
-- construction interns 'UT'.
pattern T :: Node
pattern T <- (unintern -> UT)
  where T = intern UT

-- | The @False@ leaf. Matching inspects the uninterned form for 'UF';
-- construction interns 'UF'.
pattern F :: Node
pattern F <- (unintern -> UF)
  where F = intern UF

-- | A leaf node carrying its Boolean value; @Leaf True@ is 'T' and
-- @Leaf False@ is 'F'.
pattern Leaf :: Bool -> Node
pattern Leaf b <- (asBool -> Just b)
  where Leaf b = if b then T else F

-- | View a node as a Boolean leaf: @Just True@ for 'T', @Just False@ for
-- 'F', and 'Nothing' for any branch node.
asBool :: Node -> Maybe Bool
asBool node = case node of
  T -> Just True
  F -> Just False
  _ -> Nothing

-- | A branch node: an 'Int' index together with its low and high children.
-- Matching inspects the uninterned form; construction interns a 'UBranch'.
pattern Branch :: Int -> Node -> Node -> Node
pattern Branch x l r <- (unintern -> UBranch x l r)
  where Branch x l r = intern (UBranch x l r)

{-# COMPLETE T, F, Branch #-}
{-# COMPLETE Leaf, Branch #-}

-- | Uninterned (structural) form of a 'Node': the @True@ leaf ('UT'), the
-- @False@ leaf ('UF'), or a branch with an 'Int' index and its low and high
-- child nodes ('UBranch').
data UNode
  = UT
  | UF
  | UBranch {-# UNPACK #-} !Int Node Node
  deriving (Show)

-- | Hash-consing for 'Node' via the intern package.
instance Interned Node where
  type Uninterned Node = UNode

  -- Cache key for a node: its shape, with children represented only by
  -- their ids.
  data Description Node
    = DT
    | DF
    | DBranch {-# UNPACK #-} !Int {-# UNPACK #-} !Id {-# UNPACK #-} !Id
    deriving (Eq, Generic)

  describe u = case u of
    UT -> DT
    UF -> DF
    UBranch x (Node loId _) (Node hiId _) -> DBranch x loId hiId

  identify = Node

  cache = nodeCache

-- Hashing of the intern-cache key. The empty instance body uses hashable's
-- default methods, which rely on the 'Generic' instance derived for
-- 'Description Node'.
instance Hashable (Description Node)

-- | Recover the structural form stored alongside the node's id.
instance Uninternable Node where
  unintern node = case node of
    Node _ u -> u

-- | Top-level intern cache used by the 'Interned' instance of 'Node'.
nodeCache :: Cache Node
nodeCache = mkCache
-- Must not be inlined: every use has to refer to this one shared cache.
{-# NOINLINE nodeCache #-}

-- | The interned identifier of a node. Node equality and hashing are
-- defined in terms of this value.
nodeId :: Node -> Id
nodeId node = case node of
  Node ident _ -> ident

-- ------------------------------------------------------------------------

-- | Initial size passed to 'C.newSized' when allocating the memoisation
-- hash tables used in this module.
defaultTableSize :: Int
defaultTableSize = 256

-- | Counts the number of nodes when viewed as a rooted directed acyclic graph.
--
-- Every distinct node reachable from the root is counted once, leaves
-- included.
numNodes :: Node -> Int
numNodes node0 = runST $ do
  visited <- C.newSized defaultTableSize
  let visit node = do
        seen <- H.lookup visited node
        case seen of
          Just _ -> return ()
          Nothing -> do
            -- Mark before descending so shared children are visited once.
            H.insert visited node ()
            case node of
              Branch _ lo hi -> mapM_ visit [lo, hi]
              _ -> return ()
  visit node0
  liftM length (H.toList visited)

-- ------------------------------------------------------------------------

-- | Signature functor of binary decision trees, BDDs, and ZDDs.
--
-- A value of type @'Sig' a@ describes one node whose children have type @a@.
data Sig a
  = SLeaf !Bool
    -- ^ A leaf carrying its Boolean value.
  | SBranch !Int a a
    -- ^ A branch: an 'Int' index followed by its low and high children.
  deriving
    ( Eq
    , Ord
    , Show
    , Read
    , Generic
    , Functor
    , Foldable
    , Traversable
    )

-- Hashing for 'Sig'. The empty instance body uses hashable's default
-- methods, which rely on the 'Generic' instance derived for 'Sig'.
instance Hashable a => Hashable (Sig a)

-- ------------------------------------------------------------------------

-- | Fold over the graph structure of 'Node'.
--
-- The first function replaces 'Branch' and the second replaces 'Leaf'.
-- Results are memoised per node in a hash table, so a shared sub-graph is
-- not folded again once its result is stored. The results for a branch's
-- children are produced lazily, so a child is folded only when the branch
-- function demands it.
--
-- Note that its type is isomorphic to @('Sig' a -> a) -> 'Node' -> a@.
fold :: (Int -> a -> a -> a) -> (Bool -> a) -> Node -> a
fold br lf bdd = runST $ do
  memo <- C.newSized defaultTableSize
  let go (Leaf b) = return (lf b)
      go p@(Branch top lo hi) = do
        cached <- H.lookup memo p
        case cached of
          Just ret -> return ret
          Nothing -> do
            -- Defer the recursive calls; they run only if 'br' forces them.
            r0 <- unsafeInterleaveST (go lo)
            r1 <- unsafeInterleaveST (go hi)
            let ret = br top r0 r1
            H.insert memo p ret  -- Note that H.insert is value-strict
            return ret
  go bdd

-- | Strict version of 'fold'.
--
-- Runs a fresh fold operation built by 'mkFold'Op' on the given node, so
-- the memo table lives only for this call.
fold' :: (Int -> a -> a -> a) -> (Bool -> a) -> Node -> a
fold' br lf bdd = runST (mkFold'Op br lf >>= \op -> op bdd)

-- | Create a strict, memoising fold operation in 'ST'.
--
-- The returned function folds a 'Node' using the given branch and leaf
-- functions. A single hash table is allocated when this action runs. Every
-- call of the returned function shares that table, so results computed for
-- one root are reused for later roots. Children are folded before their
-- parent, and leaf results are forced with '$!'.
mkFold'Op :: (Int -> a -> a -> a) -> (Bool -> a) -> ST s (Node -> ST s a)
mkFold'Op br lf = do
  memo <- C.newSized defaultTableSize
  let go (Leaf b) = return $! lf b
      go p@(Branch top lo hi) = do
        cached <- H.lookup memo p
        case cached of
          Just ret -> return ret
          Nothing -> do
            r0 <- go lo
            r1 <- go hi
            let ret = br top r0 r1
            H.insert memo p ret  -- Note that H.insert is value-strict
            return ret
  return go

-- ------------------------------------------------------------------------

-- | Graph where nodes are decorated using a functor @f@.
--
-- Keys are node ids. The occurrences of the parameter of @f@ represent
-- out-going edges: each 'Int' stored inside @f Int@ is expected to be
-- another key of the same map.
type Graph f = IntMap (f Int)

-- | Convert a node into a pointed graph: the graph together with the id of
-- the root node.
--
-- Nodes @0@ and @1@ are reserved for @SLeaf False@ and @SLeaf True@ even if
-- they are not actually used. Therefore the result may be larger than
-- 'numNodes' if the leaf nodes are not used.
toGraph :: Node -> (Graph Sig, Int)
toGraph bdd = fmap runIdentity (toGraph' (Identity bdd))

-- | Convert multiple nodes into a single graph.
--
-- Sub-graphs shared between (or within) the inputs are emitted once. Ids
-- @0@ and @1@ are always present and denote @SLeaf False@ and
-- @SLeaf True@. Branch nodes are numbered from @2@ in the order in which
-- they are completed. The second component gives the id of each input, in
-- the shape of the input container.
toGraph' :: Traversable t => t Node -> (Graph Sig, t Int)
toGraph' bs = runST $ do
  ids <- C.newSized defaultTableSize
  H.insert ids (Leaf False) 0
  H.insert ids (Leaf True) 1
  nextId <- newSTRef 2
  graphRef <- newSTRef $ IntMap.fromList [(0, SLeaf False), (1, SLeaf True)]

  let visit (Leaf b) = return (if b then 1 else 0)
      visit p@(Branch x lo hi) = do
        known <- H.lookup ids p
        case known of
          Just n -> return n
          Nothing -> do
            -- Number the children first so the new node can refer to them.
            r0 <- visit lo
            r1 <- visit hi
            n <- readSTRef nextId
            modifySTRef' nextId (+ 1)
            H.insert ids p n
            modifySTRef' graphRef (IntMap.insert n (SBranch x r0 r1))
            return n

  roots <- mapM visit bs
  g <- readSTRef graphRef
  return (g, roots)

-- | Fold over a pointed graph.
--
-- Folds every node with 'foldGraphNodes' and returns the result for the
-- root id @v@. Calls 'error' if @v@ is not a node of the graph.
foldGraph :: (Functor f, HasCallStack) => (f a -> a) -> (Graph f, Int) -> a
foldGraph f (g, v) =
  case IntMap.lookup v (foldGraphNodes f g) of
    Just x -> x
    -- The message names this function (not 'foldGraphNodes') so that a
    -- bad root id is attributed to the right caller.
    Nothing -> error ("foldGraph: invalid node id " ++ show v)

-- | Fold over graph nodes, producing the folded value for every node id.
--
-- The result map is defined in terms of itself: each node's children are
-- looked up lazily in the final map. This relies on the lazy 'IntMap'
-- values. Demanding a node whose edge points to a missing id calls 'error'.
foldGraphNodes :: (Functor f, HasCallStack) => (f a -> a) -> Graph f -> IntMap a
foldGraphNodes f m = results
  where
    results = IntMap.map (f . fmap resolve) m
    resolve v =
      IntMap.findWithDefault
        (error ("foldGraphNodes: invalid node id " ++ show v))
        v
        results

-- ------------------------------------------------------------------------