{-# LANGUAGE CPP #-}

-- | A simple overlay over Data.Map to manage unordered sets with duplicates.

module Agda.Utils.Bag where

import Prelude hiding (null, map)

import Control.Applicative hiding (empty)
import Text.Show.Functions ()

import Data.Foldable (Foldable(foldMap))
import Data.Functor.Identity
import qualified Data.List as List
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Semigroup
import qualified Data.Set as Set
import Data.Traversable

import Agda.Utils.Functor

#include "undefined.h"
import Agda.Utils.Impossible

-- | A set with duplicates.
--   Faithfully stores elements which are equal with regard to (==).
newtype Bag a = Bag
  { bag :: Map a [a]
      -- ^ The list contains all occurrences of @a@ (not just the duplicates!).
      --   Hence, the invariant: the list is never empty.
  }
  deriving (Eq, Ord)
  -- The list contains all occurrences of @a@ (not just the duplicates!).
  -- Hence the invariant: the list is never empty.
  --
  -- This is slightly wasteful, but much easier to implement
  -- in terms of @Map@ as the alternative, which is to store
  -- only the duplicates in the list.
  -- See, e.g., implementation of 'union' which would be impossible
  -- to do in the other representation.  We would need a
  -- 'Map.unionWithKey' that passes us *both* keys.
  -- But Map works under the assumption that Eq for keys is identity,
  -- it does not honor information in keys that goes beyond Ord.

------------------------------------------------------------------------
-- * Query
------------------------------------------------------------------------

-- | Is the bag empty?
null :: Bag a -> Bool
null = Map.null . bag

-- | Number of elements in the bag.  Duplicates count. O(n).
size :: Bag a -> Int
size = getSum . foldMap (Sum . length) . bag

-- | @(bag ! a)@ finds all elements equal to @a@.  O(log n).
--   Total function, returns @[]@ if none are.
(!) :: Ord a => Bag a -> a -> [a]
Bag b ! a = Map.findWithDefault [] a b

-- | O(log n).
member :: Ord a => a -> Bag a -> Bool
member a = not . notMember a

-- | O(log n).
notMember :: Ord a => a -> Bag a -> Bool
notMember a b = List.null (b ! a)

-- | Return the multiplicity of the given element. O(log n + count _ _).
count :: Ord a => a -> Bag a -> Int
count a b = length (b ! a)

------------------------------------------------------------------------
-- * Construction
------------------------------------------------------------------------

-- | O(1)
empty :: Bag a
empty = Bag $ Map.empty

-- | O(1)
singleton :: a -> Bag a
singleton a = Bag $ Map.singleton a [a]

union :: Ord a => Bag a -> Bag a -> Bag a
union (Bag b) (Bag c) = Bag $ Map.unionWith (++) b c

unions :: Ord a => [Bag a] -> Bag a
unions = Bag . Map.unionsWith (++)  . List.map bag

-- | @insert a b = union b (singleton a)@
insert :: Ord a => a -> Bag a -> Bag a
insert a = Bag . Map.insertWith (++) a [a] . bag

-- | @fromList = unions . map singleton@
fromList :: Ord a => [a] -> Bag a
fromList = Bag . Map.fromListWith (++) . List.map (\ a -> (a,[a]))

------------------------------------------------------------------------
-- * Destruction
------------------------------------------------------------------------

-- | Returns the elements of the bag, grouped by equality (==).
groups :: Bag a -> [[a]]
groups = Map.elems . bag

-- | Returns the bag, with duplicates.
toList :: Bag a -> [a]
toList = concat . groups

-- | Returns the bag without duplicates.
keys :: Bag a -> [a]
keys = Map.keys . bag
-- Works because of the invariant!
-- keys = catMaybes . map headMaybe . Map.elems . bag
--   -- Map.keys does not work, as zero copies @(a,[])@
--   -- should count as not present in the bag.

-- | Returns the bag, with duplicates.
elems :: Bag a -> [a]
elems = toList

toAscList :: Bag a -> [a]
toAscList = toList

------------------------------------------------------------------------
-- * Traversal
------------------------------------------------------------------------

map :: Ord b => (a -> b) -> Bag a -> Bag b
map f = Bag . Map.fromListWith (++) . List.map ff . Map.elems . bag
  where
    ff (a : as) = (b, b : List.map f as) where b = f a
    ff []       = __IMPOSSIBLE__

traverse' :: forall a b m . (Applicative m, Ord b) =>
             (a -> m b) -> Bag a -> m (Bag b)
traverse' f = (Bag . Map.fromListWith (++)) <.> traverse trav . Map.elems . bag
  where
    trav :: [a] -> m (b, [b])
    trav (a : as) = (\ b bs -> (b, b:bs)) <$> f a <*> traverse f as
    trav []       = __IMPOSSIBLE__

------------------------------------------------------------------------
-- Instances
------------------------------------------------------------------------

instance Show a => Show (Bag a) where
  showsPrec _ (Bag b) = ("Agda.Utils.Bag.Bag (" ++) . showsPrec 0 b . (')':)

instance Ord a => Semigroup (Bag a) where
  (<>) = union

instance Ord a => Monoid (Bag a) where
  mempty  = empty
  mappend = (<>)
  mconcat = unions

instance Foldable Bag where
  foldMap f = foldMap f . toList

-- not a Functor (only works for 'Ord'ered types)
-- not Traversable (only works for 'Ord'ered types)