{-# LANGUAGE TypeFamilies, FlexibleContexts #-}

module Data.TrieMap (
	-- * Map type
	TKey,
	TMap,
	-- * Operators
	(!),
	(\\),
	-- * Query
	null,
	size,
	member,
	notMember,
	lookup,
	findWithDefault,
	-- * Construction
	empty,
-- 	showMap,
	singleton,
	-- ** Insertion
	insert,
	insertWith,
	insertWithKey,
	-- ** Delete/Update
	delete,
	adjust,
	adjustWithKey,
	update,
	updateWithKey,
	alter,
	-- * Combine
	-- ** Union
	union,
	unionWith,
	unionWithKey,
	unionMaybeWith,
	unionMaybeWithKey,
	-- ** Difference
	difference,
	differenceWith,
	differenceWithKey,
	-- ** Intersection
	intersection,
	intersectionWith,
	intersectionWithKey,
	intersectionMaybeWith,
	intersectionMaybeWithKey,
	-- * Traversal
	-- ** Map
	map,
	mapWithKey,
	mapKeys,
	mapKeysWith,
	mapKeysMonotonic,
	-- ** Traverse
	traverseWithKey,
	-- ** Fold
	fold,
	foldWithKey,
	foldrWithKey,
	foldlWithKey,
	-- * Conversion
	elems,
	keys,
	assocs,
	-- ** Lists
	fromList,
	fromListWith,
	fromListWithKey,
	-- ** Ordered lists
	fromAscList,
	fromAscListWith,
	fromAscListWithKey,
	fromDistinctAscList,
	-- * Filter
	filter,
	filterWithKey,
	partition,
	partitionWithKey,
	mapMaybe,
	mapMaybeWithKey,
	mapEither,
	mapEitherWithKey,
	split,
	splitLookup,
	-- * Submap
	isSubmapOf,
	isSubmapOfBy,
	-- * Indexed
	predecessor,
	lookupWithIndex,
	successor,
	neighborhood,
	lookupIndex,
	predecessorAt,
	lookupAt,
	successorAt,
	neighborhoodAt,	
	-- * Min/Max
	findMin,
	findMax,
	deleteMin,
	deleteMax,
	deleteFindMin,
	deleteFindMax,
	updateMin,
	updateMax,
	updateMinWithKey,
	updateMaxWithKey,
	minView,
	maxView,
	minViewWithKey,
	maxViewWithKey
	) where

import Data.TrieMap.Class
import Data.TrieMap.Class.Instances()
import Data.TrieMap.TrieKey
import Data.TrieMap.Applicative
import Data.TrieMap.Rep
import Data.TrieMap.Rep.Instances
import Data.TrieMap.Modifiers
-- import Data.TrieMap.ReverseMap
import Data.TrieMap.Sized
import Data.TrieMap.CPair

import Control.Applicative hiding (empty)
import Control.Arrow
import Control.Monad
import Data.Maybe hiding (mapMaybe)
import Data.Monoid(Monoid(..), First(..), Last(..))
-- import Data.Foldable
-- import Data.Traversable

-- import Generics.MultiRec.Base
-- import Data.TrieMap.Regular.Base
-- import Data.TrieMap.Regular.Sized
import GHC.Exts (build)

import Prelude hiding (lookup, foldr, null, map, filter, reverse)

-- | Renders a map as @fromList@ applied to its association list (see 'assocs').
-- Implemented through 'showsPrec' so that the output is parenthesised when the map
-- appears as an argument of another constructor (e.g. @Just (fromList [...])@);
-- at top level the result is the same string as before.
instance (Show k, Show a, TKey k) => Show (TMap k a) where
	showsPrec d m = showParen (d > 10) $ showString "fromList " . shows (assocs m)

-- | Two maps are equal exactly when their association lists (see 'assocs') are equal.
instance (Eq k, TKey k, Eq a) => Eq (TMap k a) where
	(==) left right = assocs left == assocs right

-- | Maps are ordered by comparing their association lists (see 'assocs').
instance (Ord k, TKey k, Ord a) => Ord (TMap k a) where
	compare left right = compare (assocs left) (assocs right)

-- | The monoid whose identity is 'empty' and whose operation is 'union'.
instance TKey k => Monoid (TMap k a) where
	mempty = empty
	mappend left right = union left right

-- newtype Elem a k = Elem {getElem :: a}
-- | The empty map: wraps the underlying empty trie ('emptyM').
empty :: TKey k => TMap k a
empty = TMap emptyM

-- | A map holding a single association, built by inserting it into 'empty'.
singleton :: TKey k => k -> a -> TMap k a
singleton key val = insert key val empty

-- | Tests the underlying trie with 'nullM'.
null :: TKey k => TMap k a -> Bool
null (TMap trie) = nullM trie

-- | Looks up the value stored under a key, if any.  The key is converted with
-- 'toRep' before the underlying trie is queried.
lookup :: TKey k => k -> TMap k a -> Maybe a
lookup key (TMap trie) = fmap getElem (lookupM (toRep key) trie)

-- | Like 'lookup', but yields the supplied default when the lookup gives 'Nothing'.
findWithDefault :: TKey k => a -> k -> TMap k a -> a
findWithDefault def key tm = fromMaybe def (lookup key tm)

-- | Partial lookup: returns the value found by 'lookup', and calls 'error'
-- when 'lookup' yields 'Nothing'.
(!) :: TKey k => TMap k a -> k -> a
tm ! key = case lookup key tm of
	Just val -> val
	Nothing -> error "Element not found"

-- | Applies a function to the (possibly absent) value at a key; the function's
-- result ('Just' or 'Nothing') is handed back to the underlying 'alterM'.
alter :: TKey k => (Maybe a -> Maybe a) -> k -> TMap k a -> TMap k a
alter f key (TMap trie) = TMap (alterM elemSize f' (toRep key) trie) where
	-- Same function, lifted to the wrapped element type stored in the trie.
	f' old = fmap Elem (f (fmap getElem old))

-- | Projects information out of, and modifies or deletes, an individual association pair,
-- alternating over all associations in the map.
--
-- > minViewWithKey m == getFirst (extract (\ k a -> return ((k, a), Nothing)) m)
-- > updateMaxWithKey f m == maybe m snd (getLast (extract (\ k a -> return ((), f k a)) m))
--
-- In addition,
--
-- > getFirst (extract (\ k a -> if p k a then return ((k, a), Nothing) else mzero) m)
--
-- finds and removes the first association pair satisfying the predicate @p@.
--
-- Implemented by running 'extractA' in the 'WrappedMonad' applicative.
extract :: (TKey k, MonadPlus m) => (k -> a -> m (x, Maybe a)) -> TMap k a -> m (x, TMap k a)
extract f tm = unwrapMonad (extractA f' tm) where
	f' key val = WrapMonad (f key val)

-- | Generalization of 'extract' for 'Alternative' functors.
-- The callback sees unwrapped keys ('fromRep') and values; its optional
-- replacement value is re-wrapped before being passed to 'extractM'.
extractA :: (TKey k, Alternative f) => (k -> a -> f (x, Maybe a)) -> TMap k a -> f (x, TMap k a)
extractA f (TMap trie) = fmap (pairFromC . fmap TMap) (extractM elemSize f' trie) where
	f' key (Elem val) = fmap (\ (x, new) -> x `cP` fmap Elem new) (f (fromRep key) val)

-- | Like 'extract', but does not modify the map.
-- Implemented by running 'aboutA' in the 'WrappedMonad' applicative.
about :: (TKey k, MonadPlus m) => (k -> a -> m x) -> TMap k a -> m x
about f tm = unwrapMonad (aboutA f' tm) where
	f' key val = WrapMonad (f key val)

-- | Generalization of 'about' for 'Alternative' functors.
-- Runs 'extractA' with a callback that always requests 'Nothing' as the
-- replacement value, and keeps only the projected component of the result.
aboutA :: (TKey k, Alternative f) => (k -> a -> f x) -> TMap k a -> f x
aboutA f tm = fmap fst (extractA f' tm) where
	f' key val = fmap (\ x -> (x, Nothing)) (f key val)

-- | Inserts an association; when the key is already present the new value
-- replaces the old one (this is 'insertWith' with a combiner that keeps the new value).
insert :: TKey k => k -> a -> TMap k a -> TMap k a
insert key val tm = insertWith (\ new _ -> new) key val tm

-- | 'insertWithKey' with a combining function that ignores the key.
insertWith :: TKey k => (a -> a -> a) -> k -> a -> TMap k a -> TMap k a
insertWith f = insertWithKey (\ _ -> f)

-- | Inserts an association.  If the key is absent the new value is stored as is;
-- otherwise the stored value becomes @f key new old@.
insertWithKey :: TKey k => (k -> a -> a -> a) -> k -> a -> TMap k a -> TMap k a
insertWithKey f key new = alter (Just . combine) key where
	combine Nothing = new
	combine (Just old) = f key new old

-- | Removes a key by using 'alter' with a function that always answers 'Nothing'.
delete :: TKey k => k -> TMap k a -> TMap k a
delete key tm = alter (\ _ -> Nothing) key tm

-- | 'adjustWithKey' with a function that ignores the key.
adjust :: TKey k => (a -> a) -> k -> TMap k a -> TMap k a
adjust f = adjustWithKey (\ _ -> f)

-- | Replaces the value at a key, if present, with the result of the function.
-- Expressed as an 'updateWithKey' whose function always returns 'Just'.
adjustWithKey :: TKey k => (k -> a -> a) -> k -> TMap k a -> TMap k a
adjustWithKey f = updateWithKey (\ key val -> Just (f key val))

-- | Applies the function to the value at a key, if present; the 'Maybe' it
-- returns is passed on to 'alter'.  An absent key stays absent.
update :: TKey k => (a -> Maybe a) -> k -> TMap k a -> TMap k a
update f = alter (maybe Nothing f)

-- | Like 'update', but the function also receives the key being updated.
updateWithKey :: TKey k => (k -> a -> Maybe a) -> k -> TMap k a -> TMap k a
updateWithKey f key = alter (\ old -> old >>= f key) key

-- | 'foldWithKey' with a step function that ignores the key.
fold :: TKey k => (a -> b -> b) -> b -> TMap k a -> b
fold f = foldWithKey (\ _ -> f)

-- | Folds over the associations of the map via 'foldWithKeyM', handing the
-- step function unwrapped keys ('fromRep') and values.  'foldrWithKey' is a synonym.
foldWithKey, foldrWithKey :: TKey k => (k -> a -> b -> b) -> b -> TMap k a -> b
foldWithKey f z (TMap trie) = foldWithKeyM step trie z where
	step key (Elem val) = f (fromRep key) val
foldrWithKey f z tm = foldWithKey f z tm

-- | Left-style fold over the associations of the map via 'foldlWithKeyM',
-- handing the step function the accumulator followed by the unwrapped key and value.
foldlWithKey :: TKey k => (b -> k -> a -> b) -> b -> TMap k a -> b
foldlWithKey f z (TMap trie) = foldlWithKeyM step trie z where
	step key acc (Elem val) = f acc (fromRep key) val

-- | Runs an applicative action on every association (via 'traverseWithKeyM')
-- and collects the results into a map inside the applicative.
traverseWithKey :: (TKey k, Applicative f) => (k -> a -> f b) -> TMap k a -> f (TMap k b)
traverseWithKey f (TMap trie) = fmap TMap (traverseWithKeyM elemSize f' trie) where
	f' key (Elem val) = fmap Elem (f (fromRep key) val)

-- | Maps a function over the values; this is 'fmap' for 'TMap'.
map :: TKey k => (a -> b) -> TMap k a -> TMap k b
map f tm = fmap f tm

-- | Maps a function over the values, also giving it each value's key.
mapWithKey :: TKey k => (k -> a -> b) -> TMap k a -> TMap k b
mapWithKey f (TMap trie) = TMap (mapWithKeyM elemSize f' trie) where
	f' key (Elem val) = Elem (f (fromRep key) val)

-- | Applies a function to every key.  The map is rebuilt with 'fromList' from
-- the transformed association list, so keys that collide are merged as 'fromList' does.
mapKeys :: (TKey k, TKey k') => (k -> k') -> TMap k a -> TMap k' a
mapKeys f tm = fromList rekeyed where
	rekeyed = [(f key, val) | (key, val) <- assocs tm]

-- | Applies a function to every key, rebuilding the map with 'fromListWith' so
-- that values whose keys collide are merged by the given combining function.
mapKeysWith :: (TKey k, TKey k') => (a -> a -> a) -> (k -> k') -> TMap k a -> TMap k' a
mapKeysWith g f tm = fromListWith g rekeyed where
	rekeyed = [(f key, val) | (key, val) <- assocs tm]

-- | Applies a function to every key and rebuilds the map with
-- 'fromDistinctAscList'; the caller is therefore expected to supply a function
-- under which the transformed keys are still distinct and ascending.
mapKeysMonotonic :: (TKey k, TKey k') => (k -> k') -> TMap k a -> TMap k' a
mapKeysMonotonic f tm = fromDistinctAscList rekeyed where
	rekeyed = [(f key, val) | (key, val) <- assocs tm]

-- | Union of two maps; this is 'unionWith' 'const', so on a shared key the
-- combining function keeps its first argument.
union :: TKey k => TMap k a -> TMap k a -> TMap k a
union left right = unionWith const left right

-- | 'unionWithKey' with a combining function that ignores the key.
unionWith :: TKey k => (a -> a -> a) -> TMap k a -> TMap k a -> TMap k a
unionWith f = unionWithKey (\ _ -> f)

-- | Union in which values at shared keys are merged by the given function.
-- Expressed as a 'unionMaybeWithKey' whose function always returns 'Just'.
unionWithKey :: TKey k => (k -> a -> a -> a) -> TMap k a -> TMap k a -> TMap k a
unionWithKey f = unionMaybeWithKey f' where
	f' key x y = Just (f key x y)

-- | 'unionMaybeWithKey' with a combining function that ignores the key.
unionMaybeWith :: TKey k => (a -> a -> Maybe a) -> TMap k a -> TMap k a -> TMap k a
unionMaybeWith f = unionMaybeWithKey (\ _ -> f)

-- | Union of two maps via 'unionM'.  The combining function receives the
-- unwrapped key and both values, and its 'Maybe' result is passed back to 'unionM'.
unionMaybeWithKey :: TKey k => (k -> a -> a -> Maybe a) -> TMap k a -> TMap k a -> TMap k a
unionMaybeWithKey f (TMap left) (TMap right) = TMap (unionM elemSize combine left right) where
	combine key (Elem x) (Elem y) = fmap Elem (f (fromRep key) x y)

-- | A 'unionMaybeWith' whose combining function always answers 'Nothing'.
symmetricDifference :: TKey k => TMap k a -> TMap k a -> TMap k a
symmetricDifference left right = unionMaybeWith discard left right where
	discard _ _ = Nothing

-- | Intersection of two maps; this is 'intersectionWith' 'const', so the
-- combining function keeps its first argument.
intersection :: TKey k => TMap k a -> TMap k b -> TMap k a
intersection left right = intersectionWith const left right

-- | 'intersectionWithKey' with a combining function that ignores the key.
intersectionWith :: TKey k => (a -> b -> c) -> TMap k a -> TMap k b -> TMap k c
intersectionWith f = intersectionWithKey (\ _ -> f)

-- | Intersection in which paired values are merged by the given function.
-- Expressed as an 'intersectionMaybeWithKey' whose function always returns 'Just'.
intersectionWithKey :: TKey k => (k -> a -> b -> c) -> TMap k a -> TMap k b -> TMap k c
intersectionWithKey f = intersectionMaybeWithKey f' where
	f' key x y = Just (f key x y)

-- | 'intersectionMaybeWithKey' with a combining function that ignores the key.
intersectionMaybeWith :: TKey k => (a -> b -> Maybe c) -> TMap k a -> TMap k b -> TMap k c
intersectionMaybeWith f = intersectionMaybeWithKey (\ _ -> f)

-- | Intersection of two maps via 'isectM'.  The combining function receives the
-- unwrapped key and both values, and its 'Maybe' result is passed back to 'isectM'.
intersectionMaybeWithKey :: TKey k => (k -> a -> b -> Maybe c) -> TMap k a -> TMap k b -> TMap k c
intersectionMaybeWithKey f (TMap left) (TMap right) = TMap (isectM elemSize combine left right) where
	combine key (Elem x) (Elem y) = fmap Elem (f (fromRep key) x y)

-- | Difference of two maps: a 'differenceWith' whose combining function always
-- answers 'Nothing'.  The operator form is a synonym for 'difference'.
difference, (\\) :: TKey k => TMap k a -> TMap k b -> TMap k a
difference left right = differenceWith (\ _ _ -> Nothing) left right

left \\ right = difference left right

-- | 'differenceWithKey' with a combining function that ignores the key.
differenceWith :: TKey k => (a -> b -> Maybe a) -> TMap k a -> TMap k b -> TMap k a
differenceWith f = differenceWithKey (\ _ -> f)

-- | Difference of two maps via 'diffM'.  The combining function receives the
-- unwrapped key and both values, and its 'Maybe' result is passed back to 'diffM'.
differenceWithKey :: TKey k => (k -> a -> b -> Maybe a) -> TMap k a -> TMap k b -> TMap k a
differenceWithKey f (TMap left) (TMap right) = TMap (diffM elemSize combine left right) where
	combine key (Elem x) (Elem y) = fmap Elem (f (fromRep key) x y)

-- | Like 'minViewWithKey' / 'maxViewWithKey', but the key is dropped from the
-- extracted association.
minView, maxView :: TKey k => TMap k a -> Maybe (a, TMap k a)
minView = fmap (first snd) . minViewWithKey
maxView = fmap (first snd) . maxViewWithKey

-- | The association reported by 'minViewWithKey' / 'maxViewWithKey'.
-- Calls 'error' when the view is 'Nothing'.
findMin, findMax :: TKey k => TMap k a -> (k, a)
findMin tm = case minViewWithKey tm of
	Nothing -> error "empty map has no minimal element"
	Just (assoc, _) -> assoc
findMax tm = case maxViewWithKey tm of
	Nothing -> error "empty map has no maximal element"
	Just (assoc, _) -> assoc

-- | The remainder of the map reported by 'minViewWithKey' / 'maxViewWithKey';
-- the map is returned unchanged when the view is 'Nothing'.
deleteMin, deleteMax :: TKey k => TMap k a -> TMap k a
deleteMin tm = case minViewWithKey tm of
	Nothing -> tm
	Just (_, rest) -> rest
deleteMax tm = case maxViewWithKey tm of
	Nothing -> tm
	Just (_, rest) -> rest

-- | 'updateMinWithKey' / 'updateMaxWithKey' with a function that ignores the key.
updateMin, updateMax :: TKey k => (a -> Maybe a) -> TMap k a -> TMap k a
updateMin f = updateMinWithKey (\ _ -> f)
updateMax f = updateMaxWithKey (\ _ -> f)

-- | Updates or deletes one association, chosen by running 'extract' in the
-- 'First' (for the minimum) or 'Last' (for the maximum) monoid wrapper.
-- The map is returned unchanged when 'extract' produces no result.
updateMinWithKey, updateMaxWithKey :: TKey k => (k -> a -> Maybe a) -> TMap k a -> TMap k a
updateMinWithKey f tm = case getFirst (extract (\ key val -> return ((), f key val)) tm) of
	Nothing -> tm
	Just (_, updated) -> updated
updateMaxWithKey f tm = case getLast (extract (\ key val -> return ((), f key val)) tm) of
	Nothing -> tm
	Just (_, updated) -> updated

-- | The result of 'minViewWithKey' / 'maxViewWithKey' without the 'Maybe'.
-- Calls 'error' when the view is 'Nothing'.
deleteFindMin, deleteFindMax :: TKey k => TMap k a -> ((k, a), TMap k a)
deleteFindMin tm = case minViewWithKey tm of
	Just view -> view
	Nothing -> error "Cannot return the minimal element of an empty map"
deleteFindMax tm = case maxViewWithKey tm of
	Just view -> view
	Nothing -> error "Cannot return the maximal element of an empty map"

-- | Removes one association and returns it together with the remaining map,
-- or 'Nothing' if 'extract' produces no result.  The association is selected
-- by running 'extract' in 'First' (minimum) or 'Last' (maximum).
minViewWithKey, maxViewWithKey :: TKey k => TMap k a -> Maybe ((k, a), TMap k a)
minViewWithKey tm = getFirst (extract (\ key val -> return ((key, val), Nothing)) tm)
maxViewWithKey tm = getLast (extract (\ key val -> return ((key, val), Nothing)) tm)

-- | The values of the map, in the order produced by 'assocs'.
elems :: TKey k => TMap k a -> [a]
elems tm = fmap snd (assocs tm)

-- | The keys of the map, in the order produced by 'assocs'.
keys :: TKey k => TMap k a -> [k]
keys tm = fmap fst (assocs tm)

-- | The key/value pairs of the map as a list, produced by 'foldWithKey' and
-- wrapped in 'build'.
assocs :: TKey k => TMap k a -> [(k, a)]
assocs tm = build (\ cons nil -> foldWithKey (\ key val -> cons (key, val)) nil tm)

-- | 'mapEitherWithKey' with a function that ignores the key.
mapEither :: TKey k => (a -> Either b c) -> TMap k a -> (TMap k b, TMap k c)
mapEither f = mapEitherWithKey (\ _ -> f)

-- | Maps every association to an 'Either' and splits the map accordingly:
-- 'Left' results go to the first map of the pair, 'Right' results to the second.
mapEitherWithKey :: TKey k => (k -> a -> Either b c) -> TMap k a -> (TMap k b, TMap k c)
mapEitherWithKey f (TMap trie) = case mapEitherM elemSize elemSize classify trie of
	(lefts, rights) -> (TMap lefts, TMap rights)
	where	classify key (Elem val) = either
				(\ b -> (Just (Elem b), Nothing))
				(\ c -> (Nothing, Just (Elem c)))
				(f (fromRep key) val)

-- | 'mapMaybeWithKey' with a function that ignores the key.
mapMaybe :: TKey k => (a -> Maybe b) -> TMap k a -> TMap k b
mapMaybe f = mapMaybeWithKey (\ _ -> f)

-- | Maps every association to a 'Maybe', keeping only the 'Just' results.
-- Implemented with 'mapEitherM': nothing is ever sent to the first output,
-- and only the second output is kept.
mapMaybeWithKey :: TKey k => (k -> a -> Maybe b) -> TMap k a -> TMap k b
mapMaybeWithKey f (TMap trie) = TMap (snd (mapEitherM elemSize elemSize onlyRight trie)) where
	onlyRight key (Elem val) = (Nothing, fmap Elem (f (fromRep key) val))

-- | 'partitionWithKey' with a predicate that ignores the key.
partition :: TKey k => (a -> Bool) -> TMap k a -> (TMap k a, TMap k a)
partition p = partitionWithKey (\ _ -> p)

-- | Splits the map by a predicate: associations satisfying it end up in the
-- first map of the pair, the others in the second.
partitionWithKey :: TKey k => (k -> a -> Bool) -> TMap k a -> (TMap k a, TMap k a)
partitionWithKey p = mapEitherWithKey side where
	side key val
		| p key val = Left val
		| otherwise = Right val

-- | 'filterWithKey' with a predicate that ignores the key.
filter :: TKey k => (a -> Bool) -> TMap k a -> TMap k a
filter p = filterWithKey (\ _ -> p)

-- | Keeps only the associations satisfying the predicate.
filterWithKey :: TKey k => (k -> a -> Bool) -> TMap k a -> TMap k a
filterWithKey p = mapMaybeWithKey keep where
	keep key val
		| p key val = Just val
		| otherwise = Nothing

-- | The first and third components of 'splitLookup'; the value found at the
-- key itself is discarded.
split :: TKey k => k -> TMap k a -> (TMap k a, TMap k a)
split key tm = dropMiddle (splitLookup key tm) where
	dropMiddle (below, _, above) = (below, above)

-- | Splits the map around a key via 'splitLookupM', returning the two side
-- maps together with the value stored at the key itself, if any.
splitLookup :: TKey k => k -> TMap k a -> (TMap k a, Maybe a, TMap k a)
splitLookup key (TMap trie) = case splitLookupM elemSize probe (toRep key) trie of
	(below, found, above) -> (TMap below, found, TMap above)
	where	probe (Elem val) = (Nothing, Just val, Nothing)

-- | 'isSubmapOfBy' using '==' to compare values.
isSubmapOf :: (TKey k, Eq a) => TMap k a -> TMap k a -> Bool
isSubmapOf left right = isSubmapOfBy (==) left right

-- | Submap test via 'isSubmapM', comparing values with the supplied relation.
isSubmapOfBy :: TKey k => (a -> b -> Bool) -> TMap k a -> TMap k b -> Bool
isSubmapOfBy le (TMap left) (TMap right) = isSubmapM le' left right where
	le' (Elem x) (Elem y) = le x y

-- | Builds a map from an association list; these are 'fromListWith' 'const'
-- and 'fromAscListWith' 'const' respectively.
fromList, fromAscList :: TKey k => [(k, a)] -> TMap k a
fromList xs = fromListWith const xs
fromAscList xs = fromAscListWith const xs

-- | 'fromListWithKey' / 'fromAscListWithKey' with a combining function that
-- ignores the key.
fromListWith, fromAscListWith :: TKey k => (a -> a -> a) -> [(k, a)] -> TMap k a
fromListWith f = fromListWithKey (\ _ -> f)
fromAscListWith f = fromAscListWithKey (\ _ -> f)

-- | Builds a map from an association list via 'fromListM' / 'fromAscListM',
-- merging values of repeated keys with the given function.
fromListWithKey, fromAscListWithKey :: TKey k => (k -> a -> a -> a) -> [(k, a)] -> TMap k a
fromListWithKey f xs = TMap (fromListM elemSize combine wrapped) where
	combine key (Elem x) (Elem y) = Elem (f (fromRep key) x y)
	wrapped = [(toRep key, Elem val) | (key, val) <- xs]
fromAscListWithKey f xs = TMap (fromAscListM elemSize combine wrapped) where
	combine key (Elem x) (Elem y) = Elem (f (fromRep key) x y)
	wrapped = [(toRep key, Elem val) | (key, val) <- xs]

-- | Builds a map from an association list via 'fromDistAscListM'; as the name
-- says, the keys are expected to be distinct and in ascending order.
fromDistinctAscList :: TKey k => [(k, a)] -> TMap k a
fromDistinctAscList xs = TMap (fromDistAscListM elemSize wrapped) where
	wrapped = [(toRep key, Elem val) | (key, val) <- xs]

-- | The size of the map as computed by 'sizeM' with 'elemSize'.
size :: TKey k => TMap k a -> Int
size (TMap trie) = sizeM elemSize trie

-- | Whether 'lookup' finds a value for the key.
member :: TKey k => k -> TMap k a -> Bool
member key tm = isJust (lookup key tm)

-- | Negation of 'member'.
notMember :: TKey k => k -> TMap k a -> Bool
notMember key tm = not (member key tm)

-- showMap :: (TKey k, Show (TrieMap (Rep k) (Elem a) (Rep k))) => TMap k a -> String
-- showMap (TMap m) = show m

-- | @'predecessor' k m@ returns the index, key, and value of the immediate predecessor of @k@ in the map.
-- The predecessor is the element with the largest key @< k@.
-- This is the first component of 'neighborhood'.
predecessor :: TKey k => k -> TMap k a -> Maybe (Int, k, a)
predecessor key tm = case neighborhood key tm of
	(before, _, _) -> before

-- | The index component of 'lookupWithIndex', if the key is found.
lookupIndex :: TKey k => k -> TMap k a -> Maybe Int
lookupIndex key tm = fmap index (lookupWithIndex key tm) where
	index (i, _, _) = i

-- | Projections for the first, second and third component of a triple.
fst3 :: (a, b, c) -> a
fst3 (x, _, _) = x

snd3 :: (a, b, c) -> b
snd3 (_, y, _) = y

thd3 :: (a, b, c) -> c
thd3 (_, _, z) = z

-- | Partial version of 'lookupIndex': calls 'error' when 'lookupIndex' yields 'Nothing'.
findIndex :: TKey k => k -> TMap k a -> Int
findIndex key tm = case lookupIndex key tm of
	Just i -> i
	Nothing -> error "element is not in the map"

-- | The index, key and value for the given key itself, if present.
-- This is the second component of 'neighborhood'.
lookupWithIndex :: TKey k => k -> TMap k a -> Maybe (Int, k, a)
lookupWithIndex key tm = case neighborhood key tm of
	(_, at, _) -> at

-- | The third component of 'neighborhood' (the counterpart of 'predecessor').
successor :: TKey k => k -> TMap k a -> Maybe (Int, k, a)
successor key tm = case neighborhood key tm of
	(_, _, after) -> after

-- | Queries the trie around a key with 'lookupIxM' and returns three optional
-- @(index, key, value)@ triples: the 'Last' entry of the first result, the
-- entry for the key itself, and the 'First' entry of the third result.
neighborhood :: TKey k => k -> TMap k a -> (Maybe (Int, k, a), Maybe (Int, k, a), Maybe (Int, k, a))
neighborhood key (TMap trie) = case lookupIxM elemSize (toRep key) trie of
		(before, at, after) -> (fmap unpack (getLast before), fmap unpack at, fmap unpack (getFirst after))
	where	unpack (Asc ix rep (Elem val)) = (ix, fromRep rep, val)

-- | The first component of 'neighborhoodAt' for the given index.
predecessorAt :: TKey k => Int -> TMap k a -> Maybe (Int, k, a)
predecessorAt ix tm = case neighborhoodAt ix tm of
	(before, _, _) -> before

-- | The second component of 'neighborhoodAt' for the given index.
lookupAt :: TKey k => Int -> TMap k a -> Maybe (Int, k, a)
lookupAt ix tm = case neighborhoodAt ix tm of
	(_, at, _) -> at

-- | The third component of 'neighborhoodAt' for the given index.
successorAt :: TKey k => Int -> TMap k a -> Maybe (Int, k, a)
successorAt ix tm = case neighborhoodAt ix tm of
	(_, _, after) -> after

-- | Like 'neighborhood', but addressed by an 'Int' index: queries the trie with
-- 'assocAtM' and returns the 'Last' entry of the first result, the middle
-- result, and the 'First' entry of the third result as @(index, key, value)@ triples.
neighborhoodAt :: TKey k => Int -> TMap k a -> (Maybe (Int, k, a), Maybe (Int, k, a), Maybe (Int, k, a))
neighborhoodAt ix (TMap trie) = case assocAtM elemSize ix trie of
		(before, at, after) -> (fmap unpack (getLast before), fmap unpack at, fmap unpack (getFirst after))
	where	unpack (Asc pos rep (Elem val)) = (pos, fromRep rep, val)

-- | The set of keys of the map: every value is replaced by @()@ and the
-- result is wrapped in 'TSet'.
keysSet :: TKey k => TMap k a -> TSet k
keysSet tm = TSet (map (const ()) tm)

-- reverseMap :: TKey k => TMap k a -> TMap (Rev k) a
-- reverseMap (TMap m) = TMap (reverse m)

-- unReverseMap :: TKey k => TMap (Rev k) a -> TMap k a
-- unReverseMap (TMap m) = TMap (unreverse m)