-- |
-- SPDX-License-Identifier: BSD-3-Clause
-- Description: Compress representation of traversable
--
-- Useful for compressing the representation of a
-- structure that has many repeating elements
-- for transmission (e.g. over the network).
module Swarm.Util.OccurrenceEncoder (
  runEncoder,
) where

import Control.Monad.Trans.State
import Data.List (sortOn)
import Data.Map (Map)
import Data.Map qualified as M

-- | State monad whose state is the 'Encoder' table built up so far.
-- 'encodeOccurrence' runs in this monad.
type OccurrenceEncoder a = State (Encoder a)

-- | Lookup table from each element seen so far to the integer index
-- allocated for it. 'encodeOccurrence' allocates a new index as the
-- current size of the map, so the first element inserted gets index 0.
newtype Encoder a = Encoder (Map a Int)

-- |
-- Given a data structure that may contain many repeated "complex" elements,
-- replace every element with a simple integer index and return, alongside
-- the re-indexed structure, the list of distinct elements.
--
-- The structure is walked in its 'Traversable' order. The first element
-- encountered is assigned index 0, the next novel element gets index 1,
-- and so on; repeated elements re-use the index allocated on first sight.
--
-- The returned list is ordered by index, so the element at position @i@
-- of the list is the one that index @i@ stands for in the returned structure.
runEncoder ::
  (Traversable t, Ord b) =>
  t b ->
  (t Int, [b])
runEncoder structure =
  -- 'runState' yields (re-indexed structure, final Encoder); '<$>' on a pair
  -- maps over the second component only, turning the Encoder into the
  -- index-ordered element list.
  getIndices <$> runState (mapM encodeOccurrence structure) emptyEncoder

-- | The initial encoder state: no elements have been seen,
-- so no indices have been allocated yet.
emptyEncoder :: Ord a => Encoder a
emptyEncoder = Encoder mempty

-- | Extract the encoded elements as a list ordered by their allocated index.
--
-- Indices are allocated by 'encodeOccurrence' as the size of the map at
-- insertion time, so they are contiguous starting from 0. Sorting the
-- entries by index therefore yields a list in which the position of each
-- element equals its index, and no information is lost by dropping the
-- indices.
getIndices :: Encoder a -> [a]
getIndices (Encoder m) = map fst $ sortOn snd $ M.toList m

-- | Translate the first occurrence of each distinct element
-- to a new integer as it is encountered.
-- Subsequent encounters re-use the allocated integer.
--
-- The new integer is the number of distinct elements seen so far, so
-- allocated indices count up from 0 without gaps.
encodeOccurrence :: Ord a => a -> OccurrenceEncoder a Int
encodeOccurrence c = do
  Encoder currentMap <- get
  -- Known element: return its existing index unchanged.
  -- Novel element: allocate the next index and record it in the state.
  maybe (cacheNewIndex currentMap) return $
    M.lookup c currentMap
 where
  -- Insert @c@ under a freshly allocated index and return that index.
  cacheNewIndex currentMap = do
    put $ Encoder $ M.insert c newIdx currentMap
    return newIdx
   where
    -- Size of the map before insertion is the next unused index.
    newIdx = M.size currentMap