{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
module Data.UnionMap.Internal where

import           Control.Applicative hiding (empty)
import           Control.Monad
import qualified Data.List           as L
import           Data.Map            (Map)
import qualified Data.Map            as M
import           Data.Maybe
import           Data.Monoid
import           Data.OpenUnion
import           Data.Semigroup      (Semigroup)
import           Prelude             hiding (lookup)

-- | A map from keys @k@ to values drawn from the open union @r@.
--
-- 'Semigroup' and 'Monoid' are derived via the newtype from the underlying
-- 'Map', so @(<>)@ is the left-biased 'M.union'. 'Semigroup' must be derived
-- alongside 'Monoid': since base-4.11 it is a superclass of 'Monoid'.
newtype UnionMap k r = UnionMap (Map k (Union r))
  deriving (Semigroup, Monoid, Show)

-- | 'True' when the map holds no entries at all.
null :: UnionMap k r -> Bool
null um = case um of
  UnionMap m -> M.null m
{-# INLINE null #-}

-- | The map with no entries.
empty :: UnionMap k r
empty = UnionMap emptyMap
  where
    emptyMap = M.empty
{-# INLINE empty #-}

-- | Number of entries, regardless of which union member each value holds.
size :: UnionMap k r -> Int
size = \(UnionMap m) -> M.size m
{-# INLINE size #-}

-- | Whether the key is present, regardless of the type of its value.
member :: Ord k => k -> UnionMap k r -> Bool
member k = \(UnionMap m) -> M.member k m
{-# INLINE member #-}

-- | Negation of 'member'.
notMember :: Ord k => k -> UnionMap k r -> Bool
notMember k um = not (member k um)
{-# INLINE notMember #-}

-- | A map containing exactly one entry, whose value is injected into the
-- union with 'liftU'.
singleton :: (Ord k, Member a as) => k -> a -> UnionMap k as
singleton k x = UnionMap (M.singleton k (liftU x))
{-# INLINE singleton #-}

-- | Apply a transformation on the underlying 'Map' and rewrap the result.
liftUM :: (Map k (Union r) -> Map k (Union s)) -> UnionMap k r -> UnionMap k s
liftUM f um = case um of
  UnionMap m -> UnionMap (f m)
{-# INLINE liftUM #-}

-- | Binary counterpart of 'liftUM': combine two underlying maps and rewrap.
liftUM2 :: (Map k (Union r) -> Map k (Union s) -> Map k (Union t)) -> UnionMap k r -> UnionMap k s -> UnionMap k t
liftUM2 f = \(UnionMap lhs) (UnionMap rhs) -> UnionMap (f lhs rhs)
{-# INLINE liftUM2 #-}

-- | Look up a key and try to extract a value of type @a@ with 'retractU'.
-- Yields 'Nothing' when the key is absent or when 'retractU' fails on the
-- stored union value.
lookup :: (Ord k, Member a as) => k -> UnionMap k as -> Maybe a
lookup k (UnionMap m) = case M.lookup k m of
  Nothing -> Nothing
  Just u  -> retractU u
{-# INLINE lookup #-}

-- | Look up a key and return the raw union value, without retracting it.
lookupU :: Ord k => k -> UnionMap k r -> Maybe (Union r)
lookupU k um = case um of
  UnionMap m -> M.lookup k m
{-# INLINE lookupU #-}

-- | Partial variant of 'lookup'. Calls 'error' when the key is missing or
-- its value cannot be retracted to type @a@.
find :: (Ord k, Member a as) => k -> UnionMap k as -> a
find k um = case lookup k um of
  Just x  -> x
  Nothing -> error "UnionMap.find: given key is not an element in the map, or is not associated a value of the type expected."
{-# INLINE find #-}

-- | Partial variant of 'lookupU': returns the raw union value stored at the
-- key, and calls 'error' when the key is missing.
--
-- No retraction happens here, so a missing key is the only failure mode.
-- The error message therefore does not mention a type mismatch.
findU :: Ord k => k -> UnionMap k r -> Union r
findU k um = case lookupU k um of
  Just u  -> u
  Nothing -> error "UnionMap.findU: given key is not an element in the map."
{-# INLINE findU #-}

-- | Like 'lookup', but returns the supplied default instead of 'Nothing'.
findWithDefault :: (Ord k, Member a as) => a -> k -> UnionMap k as -> a
findWithDefault def k = fromMaybe def . lookup k
{-# INLINE findWithDefault #-}

-- | Operator form of 'find' with the map first. Calls 'error' when the key
-- is missing or its value cannot be retracted to type @a@.
(!) :: (Ord k, Member a as) => UnionMap k as -> k -> a
um ! k = case lookup k um of
  Just x  -> x
  Nothing -> error "UnionMap.!: given key is not an element in the map, or is not associated a value of the type expected."
{-# INLINE (!) #-}

-- | Insert a value (injected with 'liftU'), replacing any existing entry
-- for the key whatever its type.
insert :: (Ord k, Member a as) => k -> a -> UnionMap k as -> UnionMap k as
insert k x (UnionMap m) = UnionMap (M.insert k (liftU x) m)
{-# INLINE insert #-}

-- | 'insertWithKey' with a combining function that ignores the key.
insertWith :: (Ord k, Member a as) => (a -> a -> a) -> k -> a -> UnionMap k as -> UnionMap k as
insertWith f = insertWithKey (const f)
{-# INLINE insertWith #-}

-- | Insert @x@ at @k@. If the key is already present and 'retractU' can
-- extract an @a@ from the stored value @old@, store @f k x old@ instead.
-- If 'retractU' fails on the stored value, it is overwritten with @x@.
insertWithKey :: (Ord k, Member a as) => (k -> a -> a -> a) -> k -> a -> UnionMap k as -> UnionMap k as
insertWithKey f k x = liftUM (M.insertWithKey combine k (liftU x))
  where
    -- 'M.insertWithKey' calls this as @combine key newValue oldValue@. The
    -- new value is always @liftU x@, so it is ignored in favour of @x@.
    combine k' _ old = maybe (liftU x) (liftU . f k' x) (retractU old)
    {-# INLINE combine #-}
{-# INLINE insertWithKey #-}

-- | Remove the entry for a key, whatever the type of its value.
delete :: Ord k => k -> UnionMap k as -> UnionMap k as
delete k (UnionMap m) = UnionMap (M.delete k m)
{-# INLINE delete #-}

-- | 'adjustWithKey' with a function that ignores the key.
adjust :: (Ord k, Member a as) => (a -> a) -> k -> UnionMap k as -> UnionMap k as
adjust f = adjustWithKey (const f)
{-# INLINE adjust #-}

-- | Apply @'hoistU' (f k)@ to the value at @k@, if the key is present.
-- NOTE(review): how non-@a@ values are treated depends on 'hoistU';
-- presumably they are left unchanged, so confirm against Data.OpenUnion.
adjustWithKey :: (Ord k, Member a as) => (k -> a -> a) -> k -> UnionMap k as -> UnionMap k as
adjustWithKey f k (UnionMap m) = UnionMap (M.adjustWithKey (hoistU . f) k m)
{-# INLINE adjustWithKey #-}

-- | 'updateWithKey' with a function that ignores the key.
update :: (Ord k, Member a as) => (a -> Maybe a) -> k -> UnionMap k as -> UnionMap k as
update f = updateWithKey (const f)
{-# INLINE update #-}

-- | Update the value at @k@ when it can be retracted to type @a@.
--
-- * If @f k x@ is @Just y@, the entry becomes @liftU y@.
-- * If @f k x@ is 'Nothing', the entry is deleted.
-- * If the key is absent, or 'retractU' fails on the stored value, the map
--   is returned unchanged.
--
-- The last case previously fell through to 'M.update' as 'Nothing', which
-- silently deleted entries holding a different member of the union.
updateWithKey :: (Ord k, Member a as) => (k -> a -> Maybe a) -> k -> UnionMap k as -> UnionMap k as
updateWithKey f k = liftUM (M.update step k)
  where
    step u = case retractU u of
      Nothing -> Just u           -- not an @a@: keep the entry untouched
      Just x  -> liftU <$> f k x  -- 'Nothing' from @f@ deletes the key
{-# INLINE updateWithKey #-}

-- | Left-biased union: on duplicate keys, the entry from the first map wins.
union :: Ord k => UnionMap k r -> UnionMap k r -> UnionMap k r
union (UnionMap lhs) (UnionMap rhs) = UnionMap (M.union lhs rhs)
{-# INLINE union #-}

-- | Left-biased union of a list of maps: earlier maps win on duplicate keys.
unions :: Ord k => [UnionMap k r] -> UnionMap k r
unions ums = UnionMap (M.unions [m | UnionMap m <- ums])
{-# INLINE unions #-}

-- | Entries of the first map whose keys do not occur in the second.
difference :: Ord k => UnionMap k r -> UnionMap k r -> UnionMap k r
difference (UnionMap lhs) (UnionMap rhs) = UnionMap (lhs `M.difference` rhs)
{-# INLINE difference #-}

-- | Infix synonym for 'difference'.
(\\) :: Ord k => UnionMap k r -> UnionMap k r -> UnionMap k r
lhs \\ rhs = difference lhs rhs
{-# INLINE (\\) #-}

-- | Entries of the first map whose keys also occur in the second; values
-- come from the first map.
intersection :: Ord k => UnionMap k r -> UnionMap k r -> UnionMap k r
intersection (UnionMap lhs) (UnionMap rhs) = UnionMap (lhs `M.intersection` rhs)
{-# INLINE intersection #-}

-- | All keys, in ascending order.
keys :: UnionMap k r -> [k]
keys um = case um of
  UnionMap m -> fst <$> M.toAscList m
{-# INLINE keys #-}

-- | 'mapWithKey' with a function that ignores the key.
map :: Member a as => (a -> a) -> UnionMap k as -> UnionMap k as
map f = mapWithKey (const f)
{-# INLINE map #-}

-- | Apply @'hoistU' (f k)@ to every entry's value.
-- NOTE(review): the effect on non-@a@ values depends on 'hoistU'.
mapWithKey :: Member a as => (k -> a -> a) -> UnionMap k as -> UnionMap k as
mapWithKey f (UnionMap m) = UnionMap (M.mapWithKey (\k -> hoistU (f k)) m)
{-# INLINE mapWithKey #-}

-- | Map a function over the raw union values, possibly changing the union.
mapU :: (Union r -> Union s) -> UnionMap k r -> UnionMap k s
mapU f = liftUM (M.map f)
{-# INLINE mapU #-}

-- | Key-aware variant of 'mapU'.
mapWithKeyU :: (k -> Union r -> Union s) -> UnionMap k r -> UnionMap k s
mapWithKeyU f (UnionMap m) = UnionMap (M.mapWithKey f m)
{-# INLINE mapWithKeyU #-}

-- | Map every raw union value to an arbitrary type, producing a plain 'Map'.
mapU' :: (Union r -> a) -> UnionMap k r -> Map k a
mapU' f (UnionMap m) = M.map f m
{-# INLINE mapU' #-}

-- | Key-aware variant of 'mapU'', producing a plain 'Map'.
mapWithKeyU' :: (k -> Union r -> a) -> UnionMap k r -> Map k a
mapWithKeyU' f um = case um of
  UnionMap m -> M.mapWithKey f m
{-# INLINE mapWithKeyU' #-}

-- | Re-embed every value into the union @s@ using 'reunion'.
rebuild :: Include r s => UnionMap k r -> UnionMap k s
rebuild (UnionMap m) = UnionMap (M.map reunion m)
{-# INLINE rebuild #-}

-- | Keep only the entries whose raw union value satisfies the predicate.
filterU :: (Union r -> Bool) -> UnionMap k r -> UnionMap k r
filterU p = liftUM (M.filter p)
{-# INLINE filterU #-}

-- | Key-aware variant of 'filterU'.
filterWithKeyU :: (k -> Union r -> Bool) -> UnionMap k r -> UnionMap k r
filterWithKeyU p (UnionMap m) = UnionMap (M.filterWithKey p m)
{-# INLINE filterWithKeyU #-}

-- | Lazy right fold over the raw union values, in ascending key order.
foldrU :: (Union r -> b -> b) -> b -> UnionMap k r -> b
foldrU f z (UnionMap m) = M.foldr f z m
{-# INLINE foldrU #-}

-- | Key-aware variant of 'foldrU'.
foldrWithKeyU :: (k -> Union r -> b -> b) -> b -> UnionMap k r -> b
foldrWithKeyU f z = \(UnionMap m) -> M.foldrWithKey f z m
{-# INLINE foldrWithKeyU #-}

-- | Strict left fold over the raw union values, in ascending key order.
foldlU' :: (a -> Union r -> a) -> a -> UnionMap k r -> a
foldlU' f z (UnionMap m) = M.foldl' f z m
{-# INLINE foldlU' #-}

-- | Key-aware variant of 'foldlU''.
foldlWithKeyU' :: (a -> k -> Union r -> a) -> a -> UnionMap k r -> a
foldlWithKeyU' f z um = case um of
  UnionMap m -> M.foldlWithKey' f z m
{-# INLINE foldlWithKeyU' #-}

-- | Debug rendering of the underlying balanced tree.
-- NOTE(review): 'Data.Map.showTree' is deprecated in newer containers
-- releases in favour of "Data.Map.Internal.Debug"; confirm the supported
-- containers range before relying on it.
showTree :: (Show k, Show (Union r)) => UnionMap k r -> String
showTree um = case um of
  UnionMap m -> M.showTree m
{-# INLINE showTree #-}