{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeOperators              #-}
module Data.UnionIntMap.Internal where

import           Control.Applicative hiding (empty)
import           Control.Monad
import           Data.IntMap         (IntMap)
import qualified Data.IntMap         as IM
import qualified Data.List           as L
import           Data.Maybe
import           Data.Monoid
import           Data.Extensible.Inclusion (Include, spread)
import           Data.Extensible.Plain (K0(..))
import           Data.Extensible.Internal  (Member)
import           Data.Extensible.Sum
import           Internal
import           Prelude             hiding (lookup)

-- | A map from 'Int' keys to values of possibly different types. Every value
-- is stored as an open union @K0 :| r@, so a single map can hold values of
-- any of the types indexed by @r@.
--
-- The 'Monoid' instance is taken from the underlying 'IntMap' through
-- GeneralizedNewtypeDeriving, so combining two maps is a left-biased union.
-- NOTE(review): since GHC 8.4 'Semigroup' is a superclass of 'Monoid'; this
-- deriving clause alone will not compile there without a 'Semigroup'
-- instance — confirm the supported GHC range.
newtype UnionIntMap r = UnionIntMap (IntMap (K0 :| r))
  deriving (Monoid)

-- | Keys of a 'UnionIntMap'; identical to the keys of the underlying 'IntMap'.
type Key = Int

-- | Does the map contain no entries at all?
null :: UnionIntMap r -> Bool
null um = case um of
  UnionIntMap m -> IM.null m
{-# INLINE null #-}

-- | The map with no entries (wraps the empty 'IntMap').
empty :: UnionIntMap r
empty = UnionIntMap mempty
{-# INLINE empty #-}

-- | Number of entries in the map, counting values of every type.
size :: UnionIntMap r -> Int
size um = IM.size inner
  where
    UnionIntMap inner = um
{-# INLINE size #-}

-- | Is the key present? The type of the stored value is not inspected.
member :: Key -> UnionIntMap r -> Bool
member key (UnionIntMap inner) = key `IM.member` inner
{-# INLINE member #-}

-- | Negation of 'member'.
notMember :: Key -> UnionIntMap r -> Bool
notMember key um = not (member key um)
{-# INLINE notMember #-}

-- | A map holding exactly one entry: @val@ stored at @key@.
singleton :: Member as a => Key -> a -> UnionIntMap as
singleton key val = insert key val empty
{-# INLINE singleton #-}

-- | Turn a transformation of the underlying 'IntMap' into a transformation of
-- 'UnionIntMap's; the union index may change from @r@ to @s@.
liftUM :: (IntMap (K0 :| r) -> IntMap (K0 :| s)) -> UnionIntMap r -> UnionIntMap s
liftUM f = \(UnionIntMap inner) -> UnionIntMap (f inner)
{-# INLINE liftUM #-}

-- | Binary counterpart of 'liftUM': lift a function combining two underlying
-- 'IntMap's to one combining two 'UnionIntMap's.
liftUM2 :: (IntMap (K0 :| r) -> IntMap (K0 :| s) -> IntMap (K0 :| t)) -> UnionIntMap r -> UnionIntMap s -> UnionIntMap t
liftUM2 f (UnionIntMap lhs) = liftUM (f lhs)
{-# INLINE liftUM2 #-}

-- | Look up a value of type @a@. Yields 'Nothing' when the key is absent or
-- when 'retract' cannot extract an @a@ from the stored union value.
lookup :: Member as a => Key -> UnionIntMap as -> Maybe a
lookup k (UnionIntMap m) = do
  u <- IM.lookup k m
  retract u
{-# INLINE lookup #-}

-- | Look up the raw union value at a key, whatever type it holds.
lookupU :: Key -> UnionIntMap r -> Maybe (K0 :| r)
lookupU key um = IM.lookup key inner
  where
    UnionIntMap inner = um
{-# INLINE lookupU #-}

-- | Partial version of 'lookup': calls 'error' when the key is absent or the
-- stored value is not of the requested type.
find :: Member as a => Key -> UnionIntMap as -> a
find k um = case lookup k um of
  Just x  -> x
  Nothing -> error "UnionIntMap.find: given key is not an element in the map, or is not associated a value of the type expected."
{-# INLINE find #-}

-- | Partial version of 'lookupU': return the union value stored at a key,
-- whatever its type.
--
-- Calls 'error' if the key is absent. 'lookupU' performs no type check, so a
-- missing key is the only possible failure; the message says exactly that
-- (it previously also blamed a type mismatch that cannot occur here).
findU :: Key -> UnionIntMap r -> K0 :| r
findU k um = case lookupU k um of
  Just u  -> u
  Nothing -> error "UnionIntMap.findU: given key is not an element in the map."
{-# INLINE findU #-}

-- | Like 'lookup', but returns @def@ when the key is absent or holds a value
-- of another type.
findWithDefault :: Member as a => a -> Key -> UnionIntMap as -> a
findWithDefault def k = fromMaybe def . lookup k
{-# INLINE findWithDefault #-}

-- | Operator form of 'find' with the map first; calls 'error' when the key is
-- absent or the stored value is not of the requested type.
(!) :: Member as a => UnionIntMap as -> Key -> a
um ! k = case lookup k um of
  Nothing -> error "UnionIntMap.!: given key is not an element in the map, or is not associated a value of the type expected."
  Just x  -> x
{-# INLINE (!) #-}

-- | Store @x@ at key @k@, replacing any previous value of any type.
insert :: Member as a => Key -> a -> UnionIntMap as -> UnionIntMap as
insert k x = liftUM (IM.insert k boxed)
  where
    -- wrap the plain value into the open union before storing it
    boxed = embed (K0 x)
{-# INLINE insert #-}

-- | 'insertWithKey' with a combining function that ignores the key.
insertWith :: Member as a => (a -> a -> a) -> Key -> a -> UnionIntMap as -> UnionIntMap as
insertWith f = insertWithKey (const f)
{-# INLINE insertWith #-}

-- | Insert @x@ at key @k@, combining it with an existing value when possible.
--
-- * Key absent: @x@ is stored.
-- * Key holds a value @old@ of type @a@: it becomes @f k x old@.
-- * Key holds a value of some other type: it is overwritten by @x@.
insertWithKey :: Member as a => (Key -> a -> a -> a) -> Key -> a -> UnionIntMap as -> UnionIntMap as
insertWithKey f k x = liftUM (IM.insertWithKey combine k fresh)
  where
    fresh = embed (K0 x)
    -- the "new value" argument supplied by IM.insertWithKey is just 'fresh',
    -- so it is ignored in favour of the unwrapped @x@
    combine key _ old = case retract old of
      Just prev -> embed (K0 (f key x prev))
      Nothing   -> fresh
    {-# INLINE combine #-}
{-# INLINE insertWithKey #-}

-- | Remove the entry at a key, whatever its type; no-op if the key is absent.
delete :: Key -> UnionIntMap as -> UnionIntMap as
delete k = liftUM (IM.delete k)
{-# INLINE delete #-}

-- | 'adjustWithKey' with a function that ignores the key.
adjust :: Member as a => (a -> a) -> Key -> UnionIntMap as -> UnionIntMap as
adjust f = adjustWithKey (const f)
{-# INLINE adjust #-}

-- | Apply @f k@ to the value at key @k@ through 'modify'; an absent key leaves
-- the map unchanged. NOTE(review): behaviour on a value of another type is
-- whatever 'modify' does (presumably leaves it untouched) — confirm.
adjustWithKey :: Member as a => (Key -> a -> a) -> Key -> UnionIntMap as -> UnionIntMap as
adjustWithKey f k = liftUM (IM.adjust step k)
  where
    step = modify (f k)
{-# INLINE adjustWithKey #-}

-- | 'updateWithKey' with a function that ignores the key.
update :: Member as a => (a -> Maybe a) -> Key -> UnionIntMap as -> UnionIntMap as
update f = updateWithKey (const f)
{-# INLINE update #-}

-- | Update the value of type @a@ stored at key @k@ using @f k@.
--
-- * Key absent: the map is returned unchanged.
-- * Key holds @x :: a@: if @f k x@ is @Just x'@ the value becomes @x'@;
--   if it is 'Nothing' the entry is deleted.
-- * Key holds a value of another type: @f@ cannot be applied, so the entry
--   is left untouched.
updateWithKey :: Member as a => (Key -> a -> Maybe a) -> Key -> UnionIntMap as -> UnionIntMap as
updateWithKey f k = liftUM (IM.update step k)
  where
    -- 'IM.update' deletes the entry whenever the step returns 'Nothing', so a
    -- failed 'retract' (type mismatch) must yield @Just u@, not 'Nothing';
    -- otherwise an entry @f@ never saw would be silently dropped.
    step u = case retract u of
      Nothing -> Just u
      Just x  -> embed . K0 <$> f k x
    {-# INLINE step #-}
{-# INLINE updateWithKey #-}

-- | Left-biased union: on a key present in both maps, the left entry wins.
union :: UnionIntMap r -> UnionIntMap r -> UnionIntMap r
union (UnionIntMap lhs) (UnionIntMap rhs) = UnionIntMap (IM.union lhs rhs)
{-# INLINE union #-}

-- | Union of a list of maps; earlier maps take precedence on shared keys.
unions :: [UnionIntMap r] -> UnionIntMap r
unions ums = UnionIntMap (IM.unions [inner | UnionIntMap inner <- ums])
{-# INLINE unions #-}

-- | Entries of the first map whose keys do not occur in the second.
difference :: UnionIntMap r -> UnionIntMap r -> UnionIntMap r
difference (UnionIntMap lhs) (UnionIntMap rhs) = UnionIntMap (lhs `IM.difference` rhs)
{-# INLINE difference #-}

-- | Operator synonym for 'difference'.
(\\) :: UnionIntMap r -> UnionIntMap r -> UnionIntMap r
lhs \\ rhs = difference lhs rhs
{-# INLINE (\\) #-}

-- | Entries of the first map whose keys also occur in the second; values are
-- taken from the first map.
intersection :: UnionIntMap r -> UnionIntMap r -> UnionIntMap r
intersection (UnionIntMap lhs) (UnionIntMap rhs) = UnionIntMap (lhs `IM.intersection` rhs)
{-# INLINE intersection #-}

-- | All keys of the map, in ascending order, regardless of value type.
keys :: UnionIntMap r -> [Key]
keys um = case um of
  UnionIntMap inner -> IM.keys inner
{-# INLINE keys #-}

-- | 'mapWithKey' with a function that ignores the key.
map :: Member as a => (a -> a) -> UnionIntMap as -> UnionIntMap as
map f = mapWithKey (const f)
{-# INLINE map #-}

-- | Apply @f k@ to every entry through 'modify'. NOTE(review): entries of
-- other types are handled by 'modify' (presumably left as-is) — confirm.
mapWithKey :: Member as a => (Key -> a -> a) -> UnionIntMap as -> UnionIntMap as
mapWithKey f = liftUM (IM.mapWithKey (\k -> modify (f k)))
{-# INLINE mapWithKey #-}

-- | Map every raw union value, possibly changing the union index.
mapU :: (K0 :| r -> K0 :| s) -> UnionIntMap r -> UnionIntMap s
mapU f = mapWithKeyU (const f)
{-# INLINE mapU #-}

-- | Map every raw union value together with its key.
mapWithKeyU :: (Key -> K0 :| r -> K0 :| s) -> UnionIntMap r -> UnionIntMap s
mapWithKeyU f = liftUM (IM.mapWithKey f)
{-# INLINE mapWithKeyU #-}

-- | Convert every raw union value to a plain value, yielding an ordinary
-- 'IntMap'.
mapU' :: (K0 :| r -> a) -> UnionIntMap r -> IntMap a
mapU' f = mapWithKeyU' (const f)
{-# INLINE mapU' #-}

-- | Like 'mapU'' but the function also receives each key.
mapWithKeyU' :: (Key -> K0 :| r -> a) -> UnionIntMap r -> IntMap a
mapWithKeyU' f um = case um of
  UnionIntMap inner -> IM.mapWithKey f inner
{-# INLINE mapWithKeyU' #-}

-- | Re-index every stored value from union @r@ to union @s@ using 'spread'
-- (requires @Include s r@, presumably "every type of @r@ occurs in @s@").
rebuild :: Include s r => UnionIntMap r -> UnionIntMap s
rebuild = liftUM (IM.map spread)
{-# INLINE rebuild #-}

-- | Keep only the entries whose raw union value satisfies @p@.
filterU :: (K0 :| r -> Bool) -> UnionIntMap r -> UnionIntMap r
filterU p = filterWithKeyU (const p)
{-# INLINE filterU #-}

-- | Keep only the entries for which @p key value@ holds.
filterWithKeyU :: (Key -> K0 :| r -> Bool) -> UnionIntMap r -> UnionIntMap r
filterWithKeyU p = liftUM (IM.filterWithKey p)
{-# INLINE filterWithKeyU #-}

-- | Right fold over the raw union values in key order.
foldrU :: (K0 :| r -> b -> b) -> b -> UnionIntMap r -> b
foldrU f = foldrWithKeyU (const f)
{-# INLINE foldrU #-}

-- | Right fold over keys and raw union values (delegates to 'IM.foldrWithKey').
foldrWithKeyU :: (Key -> K0 :| r -> b -> b) -> b -> UnionIntMap r -> b
foldrWithKeyU f z um = case um of
  UnionIntMap inner -> IM.foldrWithKey f z inner
{-# INLINE foldrWithKeyU #-}

-- | Strict left fold over the raw union values, ignoring keys.
foldlU' :: (a -> K0 :| r -> a) -> a -> UnionIntMap r -> a
foldlU' f = foldlWithKeyU' (\acc _ -> f acc)
{-# INLINE foldlU' #-}

-- | Strict left fold over keys and raw union values
-- (delegates to 'IM.foldlWithKey'').
foldlWithKeyU' :: (a -> Key -> K0 :| r -> a) -> a -> UnionIntMap r -> a
foldlWithKeyU' f acc0 um = IM.foldlWithKey' f acc0 inner
  where
    UnionIntMap inner = um
{-# INLINE foldlWithKeyU' #-}

-- | Debug rendering of the underlying tree, via 'IM.showTree'.
-- NOTE(review): newer @containers@ releases deprecate/move 'showTree' out of
-- "Data.IntMap" — confirm against the supported @containers@ version.
showTree :: Show (K0 :| r) => UnionIntMap r -> String
showTree um = case um of
  UnionIntMap inner -> IM.showTree inner
{-# INLINE showTree #-}