module Hydrogen.MultiMap (
    MultiMap
  , empty
  , null
  , keys
  , elems
  , numKeys
  , numElems
  , lookup
  , member
  , insert
  , delete
  , update
  , adjust
  , toMap
  , fromMap
  , fromList
  , fromList'
  , fromSet
  , toList
  , toList'
  , union
  ) where

import Prelude hiding (lookup, foldr, null)

import Data.Foldable (Foldable, foldr)
import Data.Traversable (Traversable)
import Data.Typeable (Typeable)

import GHC.Generics (Generic)

import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified Data.List as List

-- Unqualified shorthands for the container types, so that signatures in
-- this module can mention 'Map' and 'Set' directly while the operations
-- themselves stay qualified ('Map.insert', 'Set.toList', ...).
type Map = Map.Map
type Set = Set.Set

-- | A map from keys to lists of values.
--
-- The first field holds the value list for each key. The second field
-- caches the total number of values across all keys; it is what
-- 'numElems' returns. The derived 'Eq' and 'Ord' instances compare both
-- fields.
--
-- The count is a strict field. Otherwise the arithmetic done on it by
-- every modifying operation would pile up as a chain of unevaluated
-- thunks until someone finally asked for 'numElems'.
data MultiMap k v = MultiMap (Map k [v]) !Int
  deriving (Eq, Ord, Functor, Foldable, Traversable, Generic, Typeable)

-- | Shows only the underlying 'Map'; the cached element count is omitted.
--
-- 'showsPrec' is defined, rather than 'show', so that the surrounding
-- precedence is forwarded to the 'Map' instance. A multimap nested inside
-- another constructor is then parenthesised, e.g. @Just (fromList [...])@,
-- instead of being rendered as @Just fromList [...]@. At top level the
-- output is the same as @show@ of the underlying map.
instance (Show k, Show v) => Show (MultiMap k v) where
    showsPrec d (MultiMap m _) = showsPrec d m

-- | Total number of values stored across all value lists of a map.
count :: Map k [v] -> Int
count = sum . map length . Map.elems

-- | The multimap with no keys and no values.
empty :: MultiMap k v
empty = MultiMap noKeys noValues
  where
    noKeys = Map.empty
    noValues = 0

-- | Number of keys that have an entry in the multimap.
numKeys :: MultiMap k v -> Int
numKeys mm = case mm of
    MultiMap m _ -> Map.size m

-- | Total number of values across all keys (the cached count).
numElems :: MultiMap k v -> Int
numElems mm = case mm of
    MultiMap _ total -> total

-- | 'True' when the underlying map has no keys at all.
null :: MultiMap k v -> Bool
null = Map.null . toMap

-- | All keys with an entry, in ascending key order.
keys :: MultiMap k v -> [k]
keys = Map.keys . toMap

-- | The value list of every key, in ascending key order.
elems :: MultiMap k v -> [[v]]
elems = Map.elems . toMap

-- | The values stored under a key, or the empty list if the key is absent.
lookup :: Ord k => k -> MultiMap k v -> [v]
lookup k (MultiMap m _) = Map.findWithDefault [] k m

-- | 'True' when at least one value is stored under the key.
member :: Ord k => k -> MultiMap k v -> Bool
member k mm = case lookup k mm of
    [] -> False
    _  -> True

-- | Add a value under a key. The new value is placed in front of any values
-- already stored for that key.
--
-- Exactly one value is added, so the cached count is bumped by one. This
-- avoids a separate lookup and two 'length' traversals of the key's list.
-- 'Map.insertWith' passes the new list first, so @[v] ++ old@ prepends.
insert :: Ord k => k -> v -> MultiMap k v -> MultiMap k v
insert k v (MultiMap m s) = MultiMap (Map.insertWith (++) k [v] m) (s + 1)

-- | Remove a key together with every value stored under it. Deleting a key
-- that is not present leaves the multimap unchanged.
delete :: Ord k => k -> MultiMap k v -> MultiMap k v
delete k mm@(MultiMap m s) =
    case Map.lookup k m of
      Nothing      -> mm
      Just removed -> MultiMap (Map.delete k m) (s - length removed)

-- | Replace all values stored under a key with the given list. Passing an
-- empty list removes the key entirely. The cached count is adjusted by the
-- difference between the old and new list lengths.
update :: Ord k => k -> [v] -> MultiMap k v -> MultiMap k v
update k vs mm@(MultiMap m s) =
    MultiMap (store m) (s - length (lookup k mm) + length vs)
  where
    -- An empty replacement deletes the key rather than storing [].
    store
      | List.null vs = Map.delete k
      | otherwise    = Map.insert k vs

-- | Apply a function to the values stored under a key (the empty list when
-- the key is absent) and store the result in their place. If the function
-- returns an empty list, the key is removed.
adjust :: Ord k => ([v] -> [v]) -> k -> MultiMap k v -> MultiMap k v
adjust f k mm = update k (f (lookup k mm)) mm

-- | The underlying map from each key to its list of values.
toMap :: MultiMap k v -> Map k [v]
toMap mm = case mm of
    MultiMap inner _ -> inner

-- | Build a multimap from a 'Map' of value lists.
--
-- Keys whose list is empty are dropped. This matches 'update' and 'adjust',
-- which remove a key once its list becomes empty. Without it, 'null',
-- 'numKeys' and 'keys' would report keys that hold no values while
-- 'member' and 'numElems' would not.
fromMap :: Map k [v] -> MultiMap k v
fromMap m = MultiMap nonEmpty (count nonEmpty)
  where
    nonEmpty = Map.filter (not . List.null) m

-- | Every key paired with its value list, in ascending key order.
toList :: MultiMap k v -> [(k, [v])]
toList = Map.assocs . toMap

-- | Flatten to one @(key, value)@ pair per stored value. Pairs are ordered
-- by ascending key, then by position within that key's value list.
toList' :: MultiMap k v -> [(k, v)]
toList' mm = [(k, v) | (k, vs) <- toList mm, v <- vs]

-- | Build a multimap from key/value-list pairs.
--
-- As with 'Map.fromList', when a key occurs more than once the last pair for
-- that key wins. The element count is computed from the resulting map, so
-- values in overwritten pairs are not counted. (Summing the input list
-- over-counted them.) Keys with an empty list are dropped, as 'update' and
-- 'adjust' do.
fromList :: Ord k => [(k, [v])] -> MultiMap k v
fromList xs = MultiMap m (count m)
  where
    m = Map.filter (not . List.null) (Map.fromList xs)

-- | Build a multimap from individual key/value pairs. Values that share a
-- key keep the order in which they appear in the input list, because the
-- fold runs from the right and 'insert' prepends.
fromList' :: Ord k => [(k, v)] -> MultiMap k v
fromList' pairs = foldr step empty pairs
  where
    step ~(k, v) acc = insert k v acc

-- | Build a multimap whose keys come from a set, with each key's values
-- produced by the given function.
--
-- Keys for which the function returns an empty list are dropped. This is
-- consistent with 'update' and 'adjust', which remove a key once its list
-- becomes empty.
fromSet :: forall k. forall v. Ord k => (k -> [v]) -> Set k -> MultiMap k v
fromSet f s = MultiMap m (count m)
  where
    m :: Map k [v]
    m = Map.filter (not . List.null) raw
    raw :: Map k [v]
#if MIN_VERSION_containers(5,0,0)
    raw = Map.fromSet f s
#else
    -- Set.toList yields distinct keys in ascending order, so the map can be
    -- built directly without re-sorting.
    raw = Map.fromDistinctAscList [(x, f x) | x <- Set.toList s]
#endif

-- | Combine two multimaps. For a key present in both, the values from the
-- first multimap come before those from the second. The element counts
-- are added.
union :: Ord k => MultiMap k v -> MultiMap k v -> MultiMap k v
union (MultiMap left nLeft) (MultiMap right nRight) =
    MultiMap merged (nLeft + nRight)
  where
    merged = Map.unionWith (++) left right