{-# LANGUAGE TypeFamilies, FlexibleContexts #-}

module Data.TrieMap (
	-- * Map type
	TKey,
	TMap,
	-- * Operators
	(!),
	(\\),
	-- * Query
	null,
	size,
	member,
	notMember,
	lookup,
	findWithDefault,
	-- * Construction
	empty,
	showMap,
	singleton,
	-- ** Insertion
	insert,
	insertWith,
	insertWithKey,
	-- ** Delete/Update
	delete,
	adjust,
	adjustWithKey,
	update,
	updateWithKey,
	alter,
	-- * Combine
	-- ** Union
	union,
	unionWith,
	unionWithKey,
	unionMaybeWith,
	unionMaybeWithKey,
	-- ** Difference
	difference,
	differenceWith,
	differenceWithKey,
	-- ** Intersection
	intersection,
	intersectionWith,
	intersectionWithKey,
	intersectionMaybeWith,
	intersectionMaybeWithKey,
	-- * Traversal
	-- ** Map
	map,
	mapWithKey,
	mapKeys,
	mapKeysWith,
	mapKeysMonotonic,
	-- ** Traverse
	traverseWithKey,
	-- ** Fold
	fold,
	foldWithKey,
	foldrWithKey,
	foldlWithKey,
	-- * Conversion
	elems,
	keys,
	assocs,
	-- ** Lists
	fromList,
	fromListWith,
	fromListWithKey,
	-- ** Ordered lists
	fromAscList,
	fromAscListWith,
	fromAscListWithKey,
	fromDistinctAscList,
	-- * Filter
	filter,
	filterWithKey,
	partition,
	partitionWithKey,
	mapMaybe,
	mapMaybeWithKey,
	mapEither,
	mapEitherWithKey,
	split,
	splitLookup,
	-- * Submap
	isSubmapOf,
	isSubmapOfBy,
	-- * Min/Max
	findMin,
	findMax,
	deleteMin,
	deleteMax,
	deleteFindMin,
	deleteFindMax,
	updateMin,
	updateMax,
	updateMinWithKey,
	updateMaxWithKey,
	minView,
	maxView,
	minViewWithKey,
	maxViewWithKey
	) where

import Data.TrieMap.Class
import Data.TrieMap.Class.Instances()
import Data.TrieMap.TrieKey
import Data.TrieMap.Applicative

import Control.Applicative hiding (empty)
import Control.Arrow
import Data.Maybe hiding (mapMaybe)
import Data.Monoid(First(..), Last(..))
-- import Data.Foldable
-- import Data.Traversable

-- import Generics.MultiRec.Base
import Data.TrieMap.Regular.Base
import Data.TrieMap.Regular.Sized
import GHC.Exts (build)

import Prelude hiding (lookup, foldr, null, map, filter)

-- | Renders a map as @fromList [(k1,v1),...]@, listing the bindings in the
-- order produced by 'assocs'.
--
-- Implemented through 'showsPrec' (not 'show') so the output is
-- parenthesised when it appears as an argument of a constructor or
-- function application: @show (Just m)@ yields @Just (fromList [...])@
-- rather than the unparseable @Just fromList [...]@. At precedence 0 the
-- text is exactly what the previous @show@-only definition produced.
instance (Show k, Show a, TKey k) => Show (TMap k a) where
	showsPrec d m = showParen (d > 10) $ showString "fromList " . shows (assocs m)

-- | Structural equality: two maps are equal exactly when their association
-- lists (as produced by 'assocs') are equal.
instance (Eq k, TKey k, Eq a) => Eq (TMap k a) where
	(==) x y = (==) (assocs x) (assocs y)

-- | Maps are ordered lexicographically by their association lists
-- (as produced by 'assocs').
instance (Ord k, TKey k, Ord a) => Ord (TMap k a) where
	compare x y = compare (assocs x) (assocs y)

-- newtype Elem a k = Elem {getElem :: a}
-- | The map with no bindings.
empty :: TKey k => TMap k a
empty = TMap emptyM

-- | A map with exactly one binding, obtained by inserting into 'empty'.
singleton :: TKey k => k -> a -> TMap k a
singleton k = flip (insert k) empty

-- | Whether the map has no bindings (delegates to 'nullM').
null :: TKey k => TMap k a -> Bool
null t = case t of
	TMap m -> nullM m

-- | The value bound to a key, if any. The key is converted with 'toRep'
-- before querying the trie, and the stored 'K0' wrapper is removed.
lookup :: TKey k => k -> TMap k a -> Maybe a
lookup k (TMap m) = fmap unK0 (lookupM (toRep k) m)

-- | Like 'lookup', but returns the supplied default when the key is absent.
findWithDefault :: TKey k => a -> k -> TMap k a -> a
findWithDefault dflt k m = fromMaybe dflt (lookup k m)

-- | Partial lookup: calls 'error' when the key is not bound.
(!) :: TKey k => TMap k a -> k -> a
m ! k = case lookup k m of
	Just a -> a
	Nothing -> error "Element not found"

-- | Rewrites the binding of @k@ with @f@. @f@ receives the current value as
-- a 'Maybe' (presumably 'Nothing' when @k@ is unbound), and its result is
-- handed to 'alterM'. 'delete' relies on a 'Nothing' result removing the key.
alter :: TKey k => (Maybe a -> Maybe a) -> k -> TMap k a -> TMap k a
alter f k (TMap m) = TMap $ alterM sizeK0 (fmap K0 . f . fmap unK0) (toRep k) m

-- | Inserts @a@ at @k@, overwriting any existing value.
insert :: TKey k => k -> a -> TMap k a -> TMap k a
insert = insertWith (\ new _ -> new)

-- | Like 'insertWithKey', but the combining function ignores the key.
insertWith :: TKey k => (a -> a -> a) -> k -> a -> TMap k a -> TMap k a
insertWith f = insertWithKey (\ _ -> f)

-- | Inserts @a@ at @k@. If @k@ is already bound to @old@, the stored value
-- becomes @f k a old@ instead.
insertWithKey :: TKey k => (k -> a -> a -> a) -> k -> a -> TMap k a -> TMap k a
insertWithKey f k a = alter (\ mOld -> Just (maybe a (f k a) mOld)) k

-- | Removes the binding for a key, if there is one.
delete :: TKey k => k -> TMap k a -> TMap k a
delete k = alter (\ _ -> Nothing) k

-- | Applies a function to the value at a key, if present.
adjust :: TKey k => (a -> a) -> k -> TMap k a -> TMap k a
adjust f = adjustWithKey (\ _ -> f)

-- | Like 'adjust', but the function also receives the key.
adjustWithKey :: TKey k => (k -> a -> a) -> k -> TMap k a -> TMap k a
adjustWithKey f = updateWithKey (\ k a -> Just (f k a))

-- | Applies @f@ to the value at a key, if present. A 'Nothing' result is
-- passed on to 'alter' (which, judging by 'delete', removes the key).
update :: TKey k => (a -> Maybe a) -> k -> TMap k a -> TMap k a
update f = alter (\ mv -> mv >>= f)

-- | Like 'update', but the function also receives the key.
updateWithKey :: TKey k => (k -> a -> Maybe a) -> k -> TMap k a -> TMap k a
updateWithKey f k m = update (f k) k m

-- | Folds over the values only; 'foldWithKey' with the key ignored.
fold :: TKey k => (a -> b -> b) -> b -> TMap k a -> b
fold f = foldWithKey (\ _ -> f)

-- | Folds every binding with @f key value acc@, starting from @z@, in the
-- order 'foldWithKeyM' visits the trie. 'foldrWithKey' is an alias.
foldWithKey, foldrWithKey :: TKey k => (k -> a -> b -> b) -> b -> TMap k a -> b
foldWithKey f z (TMap m) = foldWithKeyM (\ rk (K0 v) -> f (fromRep rk) v) m z
foldrWithKey = foldWithKey

-- | Left fold: the accumulator is threaded through @f acc key value@.
foldlWithKey :: TKey k => (b -> k -> a -> b) -> b -> TMap k a -> b
foldlWithKey f z0 (TMap m) = foldlWithKeyM (\ rk acc (K0 v) -> f acc (fromRep rk) v) m z0

-- | Runs an applicative action on every binding (receiving the key as well)
-- and rebuilds a map from the results.
traverseWithKey :: (TKey k, Applicative f) => (k -> a -> f b) -> TMap k a -> f (TMap k b)
traverseWithKey f (TMap m) =
	fmap TMap (traverseWithKeyM sizeK0 (\ rk (K0 v) -> fmap K0 (f (fromRep rk) v)) m)

-- | Applies a function to every value ('fmap' specialised to 'TMap').
map :: TKey k => (a -> b) -> TMap k a -> TMap k b
map f m = fmap f m

-- | Applies a function to every value, also passing the key.
mapWithKey :: TKey k => (k -> a -> b) -> TMap k a -> TMap k b
mapWithKey f (TMap m) = TMap $ mapWithKeyM sizeK0 (\ rk (K0 v) -> K0 $ f (fromRep rk) v) m

-- | Renames every key with @f@ and rebuilds the map with 'fromList'. When
-- @f@ sends several keys to the same key, 'fromList' resolves the clash.
mapKeys :: (TKey k, TKey k') => (k -> k') -> TMap k a -> TMap k' a
mapKeys f = fromList . fmap (\ (k, a) -> (f k, a)) . assocs

-- | Like 'mapKeys', but colliding values are merged with @g@ through
-- 'fromListWith'.
mapKeysWith :: (TKey k, TKey k') => (a -> a -> a) -> (k -> k') -> TMap k a -> TMap k' a
mapKeysWith g f = fromListWith g . fmap (\ (k, a) -> (f k, a)) . assocs

-- | Like 'mapKeys', but hands the renamed bindings to 'fromDistinctAscList'
-- without checking them.
-- NOTE(review): presumably only correct when @f@ is strictly increasing
-- with respect to the key order 'assocs' produces; confirm.
mapKeysMonotonic :: (TKey k, TKey k') => (k -> k') -> TMap k a -> TMap k' a
mapKeysMonotonic f = fromDistinctAscList . fmap (\ (k, a) -> (f k, a)) . assocs

-- | Union keeping the value from the first argument for shared keys
-- (assuming 'unionM' passes the first map's value first).
union :: TKey k => TMap k a -> TMap k a -> TMap k a
union = unionWith (\ a _ -> a)

-- | Union with a combining function for keys bound in both maps.
unionWith :: TKey k => (a -> a -> a) -> TMap k a -> TMap k a -> TMap k a
unionWith f = unionWithKey (\ _ -> f)

-- | Union with a key-aware combining function for keys bound in both maps.
unionWithKey :: TKey k => (k -> a -> a -> a) -> TMap k a -> TMap k a -> TMap k a
unionWithKey f = unionMaybeWithKey (\ k a -> Just . f k a)

-- | Like 'unionMaybeWithKey', but the combining function ignores the key.
unionMaybeWith :: TKey k => (a -> a -> Maybe a) -> TMap k a -> TMap k a -> TMap k a
unionMaybeWith f = unionMaybeWithKey (\ _ -> f)

-- | General union built on 'unionM'. For a key bound in both maps,
-- @f key a b@ decides the result; 'Nothing' presumably drops the key.
unionMaybeWithKey :: TKey k => (k -> a -> a -> Maybe a) -> TMap k a -> TMap k a -> TMap k a
unionMaybeWithKey f (TMap m1) (TMap m2) = TMap (unionM sizeK0 combine m1 m2)
	where
		combine rk (K0 a) (K0 b) = fmap K0 (f (fromRep rk) a b)

-- | Union in which every key bound in both maps is dropped. This presumably
-- yields the keys bound in exactly one map, assuming 'unionM' keeps
-- unshared keys.
-- NOTE(review): not in this module's export list and not referenced
-- elsewhere in the module; either export it or remove it.
symmetricDifference :: TKey k => TMap k a -> TMap k a -> TMap k a
symmetricDifference = unionMaybeWithKey (\ _ _ _ -> Nothing)

-- | Intersection keeping the value from the first map (assuming 'isectM'
-- passes the first map's value first).
intersection :: TKey k => TMap k a -> TMap k b -> TMap k a
intersection = intersectionWith (\ a _ -> a)

-- | Intersection with a combining function for the two values.
intersectionWith :: TKey k => (a -> b -> c) -> TMap k a -> TMap k b -> TMap k c
intersectionWith f = intersectionWithKey (\ _ -> f)

-- | Intersection with a key-aware combining function.
intersectionWithKey :: TKey k => (k -> a -> b -> c) -> TMap k a -> TMap k b -> TMap k c
intersectionWithKey f = intersectionMaybeWithKey (\ k a -> Just . f k a)

-- | Like 'intersectionMaybeWithKey', but the function ignores the key.
intersectionMaybeWith :: TKey k => (a -> b -> Maybe c) -> TMap k a -> TMap k b -> TMap k c
intersectionMaybeWith f = intersectionMaybeWithKey (\ _ -> f)

-- | General intersection built on 'isectM'. For a key bound in both maps,
-- @f key a b@ gives the result; 'Nothing' presumably omits the key.
intersectionMaybeWithKey :: TKey k => (k -> a -> b -> Maybe c) -> TMap k a -> TMap k b -> TMap k c
intersectionMaybeWithKey f (TMap m1) (TMap m2) = TMap (isectM sizeK0 combine m1 m2)
	where
		combine rk (K0 a) (K0 b) = fmap K0 (f (fromRep rk) a b)

-- | Removes from the first map every key that is also bound in the second
-- (presumably; this relies on 'diffM' keeping keys absent from the second map).
difference, (\\) :: TKey k => TMap k a -> TMap k b -> TMap k a
difference = differenceWithKey (\ _ _ _ -> Nothing)

-- Operator synonym for 'difference'.
(\\) = difference

-- | Like 'differenceWithKey', but the function ignores the key.
differenceWith :: TKey k => (a -> b -> Maybe a) -> TMap k a -> TMap k b -> TMap k a
differenceWith f = differenceWithKey (\ _ -> f)

-- | Difference built on 'diffM'. For a key bound in both maps, @f key a b@
-- gives the value to keep; 'Nothing' presumably removes the key.
differenceWithKey :: TKey k => (k -> a -> b -> Maybe a) -> TMap k a -> TMap k b -> TMap k a
differenceWithKey f (TMap m1) (TMap m2) = TMap (diffM sizeK0 combine m1 m2)
	where
		combine rk (K0 a) (K0 b) = fmap K0 (f (fromRep rk) a b)

-- | The value at the minimal / maximal key together with the rest of the
-- map, or 'Nothing' for an empty map.
minView, maxView :: TKey k => TMap k a -> Maybe (a, TMap k a)
minView m = fmap (first snd) (minViewWithKey m)
maxView m = fmap (first snd) (maxViewWithKey m)

-- | The minimal / maximal binding; calls 'error' on an empty map.
findMin, findMax :: TKey k => TMap k a -> (k, a)
findMin m = case minViewWithKey m of
	Nothing -> error "empty map has no minimal element"
	Just (kv, _) -> kv
findMax m = case maxViewWithKey m of
	Nothing -> error "empty map has no maximal element"
	Just (kv, _) -> kv

-- | Drops the minimal / maximal binding; an empty map is returned unchanged.
deleteMin, deleteMax :: TKey k => TMap k a -> TMap k a
deleteMin m = case minViewWithKey m of
	Nothing -> m
	Just (_, rest) -> rest
deleteMax m = case maxViewWithKey m of
	Nothing -> m
	Just (_, rest) -> rest

-- | Applies a function to the value at the minimal / maximal key.
-- NOTE(review): a 'Nothing' result is handed to 'alterMinM' / 'alterMaxM',
-- presumably deleting that binding; confirm.
updateMin, updateMax :: TKey k => (a -> Maybe a) -> TMap k a -> TMap k a
updateMin f = updateMinWithKey (\ _ -> f)
updateMax f = updateMaxWithKey (\ _ -> f)

-- | Like 'updateMin' / 'updateMax', but the function also receives the key.
updateMinWithKey, updateMaxWithKey :: TKey k => (k -> a -> Maybe a) -> TMap k a -> TMap k a
updateMinWithKey f (TMap m) = TMap $ alterMinM sizeK0 (\ rk (K0 v) -> fmap K0 (f (fromRep rk) v)) m
updateMaxWithKey f (TMap m) = TMap $ alterMaxM sizeK0 (\ rk (K0 v) -> fmap K0 (f (fromRep rk) v)) m

-- | Splits off the minimal / maximal binding; calls 'error' on an empty map.
deleteFindMin, deleteFindMax :: TKey k => TMap k a -> ((k, a), TMap k a)
deleteFindMin m = case minViewWithKey m of
	Just r -> r
	Nothing -> error "Cannot return the minimal element of an empty map"
deleteFindMax m = case maxViewWithKey m of
	Just r -> r
	Nothing -> error "Cannot return the maximal element of an empty map"

-- | The minimal / maximal binding and the remaining map, or 'Nothing' for an
-- empty map. The trie extractors wrap their result in 'First' / 'Last'; it
-- is unwrapped here and the key converted back with 'fromRep'.
minViewWithKey, maxViewWithKey :: TKey k => TMap k a -> Maybe ((k, a), TMap k a)
minViewWithKey (TMap m) = case getFirst (extractMinM sizeK0 m) of
	Nothing -> Nothing
	Just ((rk, K0 a), rest) -> Just ((fromRep rk, a), TMap rest)
maxViewWithKey (TMap m) = case getLast (extractMaxM sizeK0 m) of
	Nothing -> Nothing
	Just ((rk, K0 a), rest) -> Just ((fromRep rk, a), TMap rest)

-- | All values, in the order of 'assocs'.
elems :: TKey k => TMap k a -> [a]
elems m = snd <$> assocs m

-- | All keys, in the order of 'assocs'.
keys :: TKey k => TMap k a -> [k]
keys m = fst <$> assocs m

-- | All bindings as a list. The list comes from a 'foldWithKey' wrapped in
-- 'build', so it can take part in foldr/build list fusion.
assocs :: TKey k => TMap k a -> [(k, a)]
assocs m = build (\ cons nil -> foldWithKey (\ k a rest -> cons (k, a) rest) nil m)

-- | 'mapEitherWithKey' with a function that ignores the key.
mapEither :: TKey k => (a -> Either b c) -> TMap k a -> (TMap k b, TMap k c)
mapEither f = mapEitherWithKey (\ _ -> f)

-- | Splits the map with a key-aware function: 'Left' results go to the
-- first map and 'Right' results to the second.
mapEitherWithKey :: TKey k => (k -> a -> Either b c) -> TMap k a -> (TMap k b, TMap k c)
mapEitherWithKey f (TMap m) = case mapEitherM sizeK0 sizeK0 route m of
	(mLeft, mRight) -> (TMap mLeft, TMap mRight)
	where
		route rk (K0 a) = either (\ b -> (Just (K0 b), Nothing)) (\ c -> (Nothing, Just (K0 c))) (f (fromRep rk) a)

-- | 'mapMaybeWithKey' with a function that ignores the key.
mapMaybe :: TKey k => (a -> Maybe b) -> TMap k a -> TMap k b
mapMaybe f = mapMaybeWithKey (\ _ -> f)

-- | Keeps the 'Just' results of a key-aware function. It reuses 'mapEitherM'
-- and discards the first map, into which nothing is ever sent.
mapMaybeWithKey :: TKey k => (k -> a -> Maybe b) -> TMap k a -> TMap k b
mapMaybeWithKey f (TMap m) = TMap . snd $ mapEitherM sizeK0 sizeK0 keep m
	where
		keep rk (K0 a) = (Nothing, K0 <$> f (fromRep rk) a)

-- | Splits the map into the bindings whose value satisfies the predicate
-- (first component) and the rest (second component).
partition :: TKey k => (a -> Bool) -> TMap k a -> (TMap k a, TMap k a)
partition p = partitionWithKey (\ _ -> p)

-- | Like 'partition', but the predicate also receives the key.
partitionWithKey :: TKey k => (k -> a -> Bool) -> TMap k a -> (TMap k a, TMap k a)
partitionWithKey p = mapEitherWithKey side where
	side k a
		| p k a = Left a
		| otherwise = Right a

-- | Keeps only the bindings whose value satisfies the predicate.
filter :: TKey k => (a -> Bool) -> TMap k a -> TMap k a
filter p = filterWithKey (\ _ -> p)

-- | Like 'filter', but the predicate also receives the key.
filterWithKey :: TKey k => (k -> a -> Bool) -> TMap k a -> TMap k a
filterWithKey p = mapMaybeWithKey keep where
	keep k a
		| p k a = Just a
		| otherwise = Nothing

-- | Like 'splitLookup', but drops the value found at the key.
split :: TKey k => k -> TMap k a -> (TMap k a, TMap k a)
split k m = case splitLookup k m of
	(below, _, above) -> (below, above)

-- | Splits the map at @k@ via 'splitLookupM'. Returns the two outer parts
-- (presumably the bindings below and above @k@) and the value at @k@, if any.
splitLookup :: TKey k => k -> TMap k a -> (TMap k a, Maybe a, TMap k a)
splitLookup k (TMap m) = case splitLookupM sizeK0 atKey (toRep k) m of
	(below, found, above) -> (TMap below, found, TMap above)
	where
		atKey (K0 v) = (Nothing, Just v, Nothing)

-- | 'isSubmapOfBy' using '(==)' on values.
isSubmapOf :: (TKey k, Eq a) => TMap k a -> TMap k a -> Bool
isSubmapOf m1 m2 = isSubmapOfBy (==) m1 m2

-- | Delegates to 'isSubmapM' with the value predicate lifted through 'K0'.
-- Presumably true when every key of the first map is bound in the second
-- and the predicate holds for each pair of values; confirm against
-- 'isSubmapM'.
isSubmapOfBy :: TKey k => (a -> b -> Bool) -> TMap k a -> TMap k b -> Bool
isSubmapOfBy le (TMap m1) (TMap m2) = isSubmapM related m1 m2 where
	related (K0 a) (K0 b) = le a b

-- | Builds a map from a list of bindings. Duplicate keys are combined with
-- 'const', so only one of the values survives; which one depends on the
-- argument order used by 'fromListM' / 'fromAscListM'.
fromList, fromAscList :: TKey k => [(k, a)] -> TMap k a
fromList = fromListWith (\ x _ -> x)
fromAscList = fromAscListWith (\ x _ -> x)

-- | Like 'fromList' / 'fromAscList', with a combining function for
-- duplicate keys.
fromListWith, fromAscListWith :: TKey k => (a -> a -> a) -> [(k, a)] -> TMap k a
fromListWith f = fromListWithKey (\ _ -> f)
fromAscListWith f = fromAscListWithKey (\ _ -> f)

-- | Key-aware variants. Each key is converted with 'toRep' and each value
-- wrapped in 'K0' before the list reaches 'fromListM' / 'fromAscListM'.
-- The combining function sees the key converted back with 'fromRep'.
-- NOTE(review): the Asc variant presumably requires keys in ascending
-- order; confirm against 'fromAscListM'.
fromListWithKey, fromAscListWithKey :: TKey k => (k -> a -> a -> a) -> [(k, a)] -> TMap k a
fromListWithKey f xs =
	TMap $ fromListM sizeK0 (\ rk (K0 a) (K0 b) -> K0 (f (fromRep rk) a b))
		(fmap (\ (k, a) -> (toRep k, K0 a)) xs)
fromAscListWithKey f xs =
	TMap $ fromAscListM sizeK0 (\ rk (K0 a) (K0 b) -> K0 (f (fromRep rk) a b))
		(fmap (\ (k, a) -> (toRep k, K0 a)) xs)

-- | Builds a map from bindings that are assumed to be distinct and
-- ascending. The list is handed to 'fromDistAscListM' unchecked.
fromDistinctAscList :: TKey k => [(k, a)] -> TMap k a
fromDistinctAscList xs = TMap $ fromDistAscListM sizeK0 (fmap (\ (k, a) -> (toRep k, K0 a)) xs)

-- | Number of bindings, as reported by 'sizeM'.
size :: TKey k => TMap k a -> Int
size t = case t of
	TMap m -> sizeM sizeK0 m

-- | Whether the key is bound in the map.
member :: TKey k => k -> TMap k a -> Bool
member k m = isJust (lookup k m)

-- | Negation of 'member'.
notMember :: TKey k => k -> TMap k a -> Bool
notMember k m = not (member k m)

-- | Shows the underlying trie rather than the bindings (presumably a
-- debugging aid). Requires a 'Show' instance for the trie type itself.
showMap :: (TKey k, Show (TrieMap (Rep k) (K0 a) (Rep k))) => TMap k a -> String
showMap t = case t of
	TMap m -> show m