-- | A multimap: a map from keys to lists of values, backed by "Data.Map".
--
-- Alongside the underlying map, a 'MultiMap' carries the total number of
-- stored values, which 'numElems' returns directly.
--
-- Several names exported here ('lookup', 'null') clash with the Prelude, so
-- this module is best imported qualified.
module Hydrogen.MultiMap (
    MultiMap
  , empty
  , null
  , keys
  , elems
  , numKeys
  , numElems
  , lookup
  , member
  , insert
  , delete
  , update
  , adjust
  , toMap
  , fromMap
  , fromList
  , fromList'
  , fromSet
  , toList
  , toList'
  , union
  ) where

import Prelude hiding (lookup, foldr, null)

import Data.Foldable (Foldable, foldr)
import Data.Traversable (Traversable)
import Data.Typeable (Typeable)

import GHC.Generics (Generic)

import Data.Map (Map)
import Data.Set (Set)

import qualified Data.Map as Map
import qualified Data.List as List

-- | A map from keys of type @k@ to lists of values of type @v@.
--
-- The first field is the underlying map from each key to its list of values.
-- The second field is the total number of values across all keys; the
-- functions in this module keep it up to date so that 'numElems' does not
-- have to traverse the map.
--
-- Note that the derived 'Eq' and 'Ord' instances compare both fields, i.e.
-- the stored count takes part in the comparison as well as the map.
data MultiMap k v = MultiMap (Map k [v]) Int
  deriving (Eq, Ord, Functor, Foldable, Traversable, Generic, Typeable)

-- | A 'MultiMap' is shown exactly like its underlying 'Map'; the stored
-- element count is not part of the output.
instance (Show k, Show v) => Show (MultiMap k v) where
    show mm = case mm of
        MultiMap inner _ -> show inner

-- | Total number of values stored in a map of lists, i.e. the sum of the
-- lengths of all the lists.
count :: Map k [v] -> Int
count = sum . map length . Map.elems

-- | The multimap with no keys and no values.
empty :: MultiMap k v
empty = MultiMap noEntries 0
  where
    noEntries = Map.empty

-- | Number of keys in the underlying map.
numKeys :: MultiMap k v -> Int
numKeys = Map.size . toMap

-- | Total number of values, as recorded in the multimap's stored count.
-- The map itself is not traversed.
numElems :: MultiMap k v -> Int
numElems mm = case mm of
    MultiMap _ total -> total

-- | 'True' when the underlying map has no keys.
null :: MultiMap k v -> Bool
null = Map.null . toMap

-- | All keys of the underlying map, in the order given by 'Map.keys'.
keys :: MultiMap k v -> [k]
keys = Map.keys . toMap

-- | The value lists of all keys, one list per key, in the order given by
-- 'Map.elems'.
elems :: MultiMap k v -> [[v]]
elems = Map.elems . toMap

-- | The values stored under a key, or the empty list when the key is not
-- present in the underlying map.
lookup :: Ord k => k -> MultiMap k v -> [v]
lookup k = Map.findWithDefault [] k . toMap

-- | 'True' when at least one value is stored under the key.
--
-- This is defined via 'lookup', so a key that is present but mapped to the
-- empty list is reported as not being a member.
member :: Ord k => k -> MultiMap k v -> Bool
member k mm = case lookup k mm of
    [] -> False
    _  -> True

-- | Add one value under a key.
--
-- The new value is placed in front of any values already stored under that
-- key; if the key is absent it is created with a singleton list. The stored
-- element count grows by exactly one.
insert :: Ord k => k -> v -> MultiMap k v -> MultiMap k v
insert k v (MultiMap m s) = MultiMap (Map.insertWith (++) k [v] m) (s + 1)
    -- 'Map.insertWith' applies the function as @f new old@, so an existing
    -- list @old@ becomes @[v] ++ old@, i.e. @v : old@.
    --
    -- The count is bumped directly rather than derived from the lengths of
    -- the old and new lists: the difference is always 1, and measuring the
    -- lists would cost time proportional to the number of values already
    -- stored under the key.

-- | Remove a key together with all of its values.
--
-- The stored element count is reduced by the number of values that were
-- stored under the key (zero if the key was absent).
delete :: Ord k => k -> MultiMap k v -> MultiMap k v
delete k (MultiMap m s) = MultiMap (Map.delete k m) (s - removed)
  where
    removed = length (Map.findWithDefault [] k m)

-- | Replace the values stored under a key with the given list.
--
-- Passing the empty list removes the key from the underlying map. The stored
-- element count is adjusted by the difference between the new and the old
-- number of values for that key.
update :: Ord k => k -> [v] -> MultiMap k v -> MultiMap k v
update k vs (MultiMap m s) = case vs of
    [] -> MultiMap (Map.delete k m) total
    _  -> MultiMap (Map.insert k vs m) total
  where
    previous = Map.findWithDefault [] k m
    total = s - length previous + length vs

-- | Transform the list of values stored under a key.
--
-- The function receives the current values (the empty list when the key is
-- absent). If it returns the empty list the key is removed from the
-- underlying map; otherwise the result is stored under the key. The stored
-- element count is adjusted by the change in list length.
adjust :: Ord k => ([v] -> [v]) -> k -> MultiMap k v -> MultiMap k v
adjust f k (MultiMap m s) = case after of
    [] -> MultiMap (Map.delete k m) total
    _  -> MultiMap (Map.insert k after m) total
  where
    before = Map.findWithDefault [] k m
    after = f before
    total = s - length before + length after

-- | The underlying map from keys to value lists.
toMap :: MultiMap k v -> Map k [v]
toMap mm = case mm of
    MultiMap inner _ -> inner

-- | Wrap an existing map of value lists as a 'MultiMap'.
--
-- The map is stored unchanged (keys mapped to the empty list are kept), and
-- the element count is computed from it with 'count'.
fromMap :: Map k [v] -> MultiMap k v
fromMap m = MultiMap m total
  where
    total = count m

-- | All keys paired with their complete value lists, in the order given by
-- 'Map.toList'.
toList :: MultiMap k v -> [(k, [v])]
toList = Map.toList . toMap

-- | All key/value pairs, flattened: each key is repeated once for every
-- value stored under it, with the values in their stored order.
toList' :: MultiMap k v -> [(k, v)]
toList' mm = [(k, v) | (k, vs) <- toList mm, v <- vs]

-- | Build a multimap from a list of keys paired with their value lists.
--
-- When the same key occurs more than once, only its last entry is kept,
-- following 'Map.fromList'. The element count is computed from the resulting
-- map rather than from the input list, so entries that were overwritten by a
-- later duplicate key are not counted and 'numElems' always agrees with the
-- stored contents.
fromList :: Ord k => [(k, [v])] -> MultiMap k v
fromList xs = MultiMap m (count m)
  where
    m = Map.fromList xs

-- | Build a multimap from a flat list of key/value pairs.
--
-- Pairs are inserted starting from the end of the list, and 'insert' puts
-- each new value in front, so the values of a key end up in the same
-- relative order as they appear in the input.
fromList' :: Ord k => [(k, v)] -> MultiMap k v
fromList' [] = empty
fromList' (pair : rest) = uncurry insert pair (fromList' rest)

-- | Build a multimap from a set of keys, using the given function to produce
-- the list of values for each key. The element count is computed from the
-- resulting map.
fromSet :: Ord k => (k -> [v]) -> Set k -> MultiMap k v
fromSet f = fromMap . Map.fromSet f

-- | Combine two multimaps.
--
-- For a key present in both, the values from the first multimap come before
-- the values from the second. The element count of the result is the sum of
-- the two stored counts.
union :: Ord k => MultiMap k v -> MultiMap k v -> MultiMap k v
union (MultiMap left leftCount) (MultiMap right rightCount) =
    MultiMap merged total
  where
    merged = Map.unionWith (++) left right
    total = leftCount + rightCount