-- | A queue where entries can be added in batches and stored compactly.
{-# LANGUAGE TypeFamilies, RecordWildCards, FlexibleContexts, ScopedTypeVariables #-}
module Data.BatchedQueue(
  Queue, Batch(..), StandardBatch, unbatch, empty, insert, removeMin, removeMinFilter, mapMaybe, toBatches, toList, size) where

import qualified Data.Heap as Heap
import Data.List(unfoldr, sort, foldl')
import qualified Data.Maybe
import Data.PackedSequence(PackedSequence)
import qualified Data.PackedSequence as PackedSequence
import Data.Serialize
import Data.Ord

-- | A queue of batches.
--
-- Internally this is a heap of batches, each wrapped in @Best@ so that
-- the heap orders batches by their smallest entry.
newtype Queue a = Queue (Heap.Heap (Best a))

-- | The type of batches must be a member of this class.
class Ord (Entry a) => Batch a where
  -- | Each batch can have an associated label,
  -- which is specified when calling 'insert'.
  -- A label represents a piece of information which is
  -- shared in common between all entries in a batch,
  -- and which might be used to store that batch more
  -- efficiently.
  -- Labels are optional, and by default @Label a = ()@.
  type Label a

  -- | Individual entries in the batch.
  type Entry a

  -- | Given a label, and a non-empty list of entries,
  -- sorted in ascending order, produce a list of batches.
  makeBatch :: Label a -> [Entry a] -> [a]

  -- | Remove the smallest entry from a batch.
  -- Returns that entry together with the remainder of the batch,
  -- or 'Nothing' if no entries are left.
  unconsBatch :: a -> (Entry a, Maybe a)

  -- | Return the label of a batch.
  batchLabel :: a -> Label a

  -- | Compute the size of a batch. Used in 'size'.
  -- The default implementation works by repeatedly calling
  -- 'unconsBatch'.
  batchSize :: a -> Int
  batchSize = length . unbatch

  -- Default instance for the associated 'Label' type.
  type Label a = ()

-- A newtype wrapper for batches which compares the smallest entry.
-- See the 'Eq' and 'Ord' instances for 'Best'.
newtype Best a = Best { unBest :: a }
-- Equality is defined through 'compare', so that 'Eq' always agrees
-- with the 'Ord' instance for 'Best'.
instance Batch a => Eq (Best a) where
  x == y = compare x y == EQ
-- Wrapped batches are ordered by the entry that 'unconsBatch' would
-- return first, i.e. the smallest entry of each batch.
instance Batch a => Ord (Best a) where
  {-# INLINEABLE compare #-}
  compare = comparing (fst . unconsBatch . unBest)

-- | Convert a batch into a list of entries.
--
-- The entries are produced by repeatedly calling 'unconsBatch' until
-- the batch is exhausted, so they appear in the order in which
-- 'unconsBatch' yields them.
{-# INLINEABLE unbatch #-}
unbatch :: Batch a => a -> [Entry a]
unbatch batch = unfoldr (fmap unconsBatch) (Just batch)

-- | The empty queue, containing no batches.
empty :: Queue a
empty = Queue Heap.empty

-- | Add entries to the queue.
--
-- The entries are sorted, turned into batches with 'makeBatch' using
-- the given label, and each resulting batch is added to the queue.
-- Inserting an empty list of entries leaves the queue unchanged
-- ('makeBatch' is never called in that case).
{-# INLINEABLE insert #-}
insert :: forall a. Batch a => Label a -> [Entry a] -> Queue a -> Queue a
insert _ [] q = q
insert l is (Queue q) =
  Queue $ foldl' (flip (Heap.insert . Best)) q (makeBatch l (sort is))

-- | Remove the minimum entry from the queue.
--
-- Equivalent to 'removeMinFilter' with a predicate that accepts every
-- label, so no batch is ever discarded.
{-# INLINEABLE removeMin #-}
removeMin :: Batch a => Queue a -> Maybe (Entry a, Queue a)
removeMin q = removeMinFilter (const True) q

-- | Remove the minimum entry from the queue, discarding any
-- batches that do not satisfy the predicate.
--
-- The predicate is applied to the label of the batch at the front of
-- the heap. A rejected batch is dropped in its entirety and the search
-- continues with the rest of the queue; batches that never reach the
-- front are not inspected. The result is 'Nothing' when
-- 'Heap.removeMin' yields no batch.
{-# INLINEABLE removeMinFilter #-}
removeMinFilter :: Batch a => (Label a -> Bool) -> Queue a -> Maybe (Entry a, Queue a)
removeMinFilter ok (Queue q) = do
  (Best batch, rest) <- Heap.removeMin q
  if not (ok (batchLabel batch)) then removeMinFilter ok (Queue rest) else
    case unconsBatch batch of
      -- The batch still has entries: put the remainder back in the heap.
      (entry, Just batch') ->
        Just (entry, Queue (Heap.insert (Best batch') rest))
      -- That was the last entry of the batch.
      (entry, Nothing) ->
        Just (entry, Queue rest)

-- | Map a function over all entries.
-- The function must preserve the label of each batch,
-- and must not split existing batches into two.
--
-- Entries for which the function returns 'Nothing' are removed, and a
-- batch that ends up with no entries is dropped from the queue. The
-- surviving entries of each batch are re-sorted and rebuilt with
-- 'makeBatch' under the batch's original label. Calls 'error' if
-- 'makeBatch' returns more than one batch.
{-# INLINEABLE mapMaybe #-}
mapMaybe :: Batch a => (Entry a -> Maybe (Entry a)) -> Queue a -> Queue a
mapMaybe f (Queue q) = Queue (Heap.mapMaybe g q)
  where
    -- Rebuild a single batch from the entries that survive @f@.
    g (Best batch) =
      case Data.Maybe.mapMaybe f (unbatch batch) of
        [] -> Nothing
        is ->
          -- @f@ need not be monotone, so sort before rebuilding.
          case makeBatch (batchLabel batch) (sort is) of
            [] -> Nothing
            [batch'] -> Just (Best batch')
            _ -> error "multiple batches produced"

-- | Convert a queue into a list of batches, in unspecified order.
{-# INLINEABLE toBatches #-}
toBatches :: Queue a -> [a]
toBatches (Queue q) = map unBest (Heap.toList q)

-- | Convert a queue into a list of entries, in unspecified order.
--
-- The result is the concatenation of 'unbatch' applied to every batch
-- returned by 'toBatches'.
{-# INLINEABLE toList #-}
toList :: Batch a => Queue a -> [Entry a]
toList q = concatMap unbatch (toBatches q)

-- | The total number of entries in the queue, computed as the sum of
-- 'batchSize' over all batches.
{-# INLINEABLE size #-}
size :: Batch a => Queue a -> Int
size = sum . map batchSize . toBatches

-- | A "standard" type of batches. By using @Queue (StandardBatch a)@,
-- you will get a queue where entries have type @a@ and labels have
-- type @()@.
data StandardBatch a =
  StandardBatch {
    -- The first entry of the batch; this is what 'unconsBatch' returns.
    batch_best :: !a,
    -- The remaining entries, stored as a 'PackedSequence'.
    batch_rest :: {-# UNPACK #-} !(PackedSequence a) }

-- Equality is defined through 'compare', so two batches are equal
-- exactly when the 'Ord' instance considers them equal.
instance Ord a => Eq (StandardBatch a) where
  x == y = compare x y == EQ
-- Batches are ordered by their first entry ('batch_best') only.
instance Ord a => Ord (StandardBatch a) where
  compare = comparing batch_best

-- A 'StandardBatch' carries no label and stores all entries but the
-- first in a 'PackedSequence'.
instance (Ord a, Serialize a) => Batch (StandardBatch a) where
  type Label (StandardBatch a) = ()
  type Entry (StandardBatch a) = a

  -- All entries go into a single batch: the head of the list becomes
  -- 'batch_best' and the tail is packed into 'batch_rest'.
  -- An empty list yields no batches, which keeps this function total.
  makeBatch _ [] = []
  makeBatch _ (x:xs) = [StandardBatch x (PackedSequence.fromList xs)]

  -- Return the first entry, and promote the next packed entry (if any)
  -- to be the first entry of the remaining batch.
  unconsBatch StandardBatch{..} =
    (batch_best,
     case PackedSequence.uncons batch_rest of
       Nothing -> Nothing
       Just (x, xs) -> Just (StandardBatch x xs))

  batchLabel _ = ()

  -- One for 'batch_best' plus the number of packed entries; avoids the
  -- default implementation, which unpacks the whole batch.
  batchSize StandardBatch{..} = 1 + PackedSequence.size batch_rest