{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE BangPatterns #-}
{-|

Union-find-like data structure that defines equivalence classes of e-class ids.

-}
module Data.Equality.Graph.ReprUnionFind
  ( ReprUnionFind
  , emptyUF
  , makeNewSet
  , unionSets
  , findRepr
  ) where

import Data.Equality.Graph.Classes.Id

#if __GLASGOW_HASKELL__ >= 902

import qualified Data.Equality.Utils.IntToIntMap as IIM
import GHC.Exts ((+#), Int(..), Int#)

-- | Type of the (unboxed) counter used to hand out fresh e-class ids.
type RUFSize = Int#

-- | A union find for equivalence classes of e-class ids.
--
-- The map field is strict so that repeated 'makeNewSet' / 'unionSets' calls
-- do not accumulate a chain of unevaluated insert thunks inside the structure.
data ReprUnionFind = RUF !IIM.IntToIntMap -- ^ Map every id to either 0# (meaning it is the representative) or to another int# (meaning it is represented by some other id)
                         RUFSize          -- ^ Counter for new ids

                         -- !(IM.IntMap [ClassId]) -- ^ Mapping from an id to all its children: This is used for "rebuilding" (compress all paths) when merging. Its a hashcons?
                         -- [ClassId] -- ^ Ids that can be safely deleted after the e-graph is rebuilt
#else

import qualified Data.IntMap.Internal as IIM (IntMap(..))
import qualified Data.IntMap.Strict as IIM

-- | A union find for equivalence classes of e-class ids.
data ReprUnionFind = RUF (IIM.IntMap Int)    -- ^ Map every id to either 0# (meaning its the representative) or to another int# (meaning its represented by some other id)
                         {-# UNPACK #-} !Int -- ^ Counter for new ids

#endif

-- Note that no value is associated with an identifier, so this union find
-- serves only to find the representative of an e-class id.

-- | Placeholder instance: the contents of the union find are not rendered,
-- only a fixed warning string is produced.
instance Show ReprUnionFind where
  show (RUF _ _) = "Warning: Incomplete show: ReprUnionFind"

-- | An @id@ can be represented by another @id@ or be canonical, meaning it
-- represents itself.
--
-- @(x, Represented y)@ would mean x is represented by y
-- @(x, Canonical)@ would mean x is canonical -- represents itself
--
-- NOTE: the @Canonical@ alternative is commented out. In 'ReprUnionFind' a
-- canonical id is instead encoded by mapping it to 0 in the underlying map.
-- This type is not exported and is not used by the functions of this module.
newtype Repr
  = Represented { unRepr :: ClassId } -- ^ @Represented x@ is represented by @x@
  --   | Canonical -- ^ @Canonical x@ is the canonical representation, meaning @find(x) == x@
  deriving Show

-- | The empty 'ReprUnionFind'.
--
-- It contains no ids, and its id counter starts at 1 because the value 0 is
-- reserved in the map to mark an id as canonical (its own representative).
emptyUF :: ReprUnionFind
-- TODO: If I can make an instance of 'ReprUnionFind' for Monoid(?), this is 'mempty'
emptyUF = RUF IIM.Nil
#if __GLASGOW_HASKELL__ >= 902
              1# -- Must start with 1# since 0# means "Canonical"
#else
              1  -- Must start with 1 since 0 means "Canonical"
#endif

-- | Create a new e-class id in the given 'ReprUnionFind'.
--
-- The fresh id is the current value of the counter. It is registered in the
-- map as canonical (mapped to 0), and the counter is incremented by one.
makeNewSet :: ReprUnionFind
           -> (ClassId, ReprUnionFind) -- ^ Newly created e-class id and updated 'ReprUnionFind'
#if __GLASGOW_HASKELL__ >= 902
makeNewSet (RUF im si) = (I# si, RUF (IIM.insert si 0# im) (si +# 1#))
#else
makeNewSet (RUF im si) = (si, RUF (IIM.insert si 0 im) (si + 1))
#endif
{-# SCC makeNewSet #-}

-- | Union operation of the union find.
--
-- Given two leader ids, unions the two eclasses making @a@ the leader, that
-- is, @b@ is now represented by @a@.
--
-- If @a@ and @b@ are the same id, the union-find is returned unchanged.
-- Recording an id as its own (non-zero) representative would create a cycle,
-- and 'findRepr' would then loop forever on that id.
unionSets :: ClassId                  -- ^ E-class id @a@
          -> ClassId                  -- ^ E-class id @b@
          -> ReprUnionFind            -- ^ Union-find containing @a@ and @b@
          -> (ClassId, ReprUnionFind) -- ^ The new leader (always @a@) and the updated union-find
#if __GLASGOW_HASKELL__ >= 902
unionSets a@(I# a#) b@(I# b#) uf@(RUF im si)
  | a == b    = (a, uf)
  | otherwise = (a, RUF (IIM.insert b# a# im) si)
#else
unionSets a b uf@(RUF im si)
  | a == b    = (a, uf)
  | otherwise = (a, RUF (IIM.insert b a im) si)
#endif
  -- where
    -- represented_by_b = hc IM.! b
    -- -- Overwrite previous id of b (which should be 0#) with new representative (a)
    -- -- AND "rebuild" all nodes represented by b by making them represented directly by a
    -- new_im = {-# SCC "rebuild_im" #-} IIM.unliftedFoldr (\(I# x) -> IIM.insert x a#) (IIM.insert b# a# im) represented_by_b
    -- new_hc = {-# SCC "adjust_hc" #-} IM.adjust ((b:) . (represented_by_b <>)) a (IM.delete b hc)
{-# SCC unionSets #-}

-- | Find the canonical representation of an e-class id.
--
-- Follows the chain of representatives stored in the map until it reaches an
-- id that maps to 0, which marks a canonical id.
--
-- The lookup is partial. In the Data.IntMap fallback, looking up an id absent
-- from the map (one never created with 'makeNewSet') calls 'error'.
-- NOTE(review): presumably the IntToIntMap lookup behaves the same; confirm.
findRepr :: ClassId -> ReprUnionFind
         -> ClassId -- ^ The found canonical representation
#if __GLASGOW_HASKELL__ >= 902
findRepr v@(I# v#) uf@(RUF m _) =
  case {-# SCC "findRepr_TAKE" #-} m IIM.! v# of
    0# -> v
    x  -> findRepr (I# x) uf
#else
findRepr v uf@(RUF m _) =
  case {-# SCC "findRepr_TAKE" #-} m IIM.! v of
    0 -> v
    x -> findRepr x uf
#endif

-- ROMES:TODO: Path compression in immutable data structure? Is it worth
-- the copy + threading?
--
-- ANSWER: According to my tests, findRepr is always quite shallow, going only
-- (from what I saw) until, at max, depth 3!
--
-- When using the ad-hoc path compression in `unionSets` (currently commented
-- out there), the depth of recursion never even goes above 1!
{-# SCC findRepr #-}


-- {-# RULES
--    "union/find" forall a b x im. findRepr (I# b) (RUF (IIM.insert b a im) x) = I# a
--   #-}

-- -- | Delete nodes that have been merged after e-graph has been rebuilt
-- rebuildUF :: ReprUnionFind -> ReprUnionFind
-- rebuildUF (RUF m' a b dl) = RUF (IIM.unliftedFoldr (\(I# x) -> IIM.delete x) m' dl) a b mempty