{-# LANGUAGE UndecidableInstances, TemplateHaskell, FlexibleContexts, TypeOperators, Rank2Types, PatternGuards, MultiParamTypeClasses, TypeFamilies #-}

module Data.TrieMap.OrdMap () where

import Data.TrieMap.TrieKey
import Data.TrieMap.Sized
-- import Data.TrieMap.Applicative
import Data.TrieMap.Modifiers
import Data.TrieMap.CPair
-- import Data.TrieMap.MultiRec.Base
-- import Data.TrieMap.Rep
-- import Data.TrieMap.Rep.TH

import Control.Applicative (Applicative(..), Alternative(..), (<$>))
-- import Control.Arrow
import Control.Monad hiding (join)

-- import Data.Monoid
-- import Data.Maybe
-- import Data.Map
-- import qualified Data.Map as Map
-- import Data.Traversable

import Prelude hiding (lookup)

-- | A weight-balanced binary search tree keyed by @k@.
--
-- A 'Bin' node stores the cached total weight of its subtree, the node's key,
-- its value, and then the left and right subtrees. The cached weight is the sum
-- of the 'Sized' measure of every value in the subtree (see 'bin' and
-- 'balance'), not the element count.
data OrdMap k a
  = Tip
  | Bin {-# UNPACK #-} !Int k a !(OrdMap k a) !(OrdMap k a)

-- | Keys wrapped in 'Ordered' are stored in an 'OrdMap'. Its 'TrieKey'
-- instance only requires an 'Ord' instance on the underlying key.
type instance TrieMap (Ordered k) = OrdMap k

-- type instance RepT (OrdMap k) = FamT KeyFam (HFix (U :+: (K Int :*: K k :*: X :*: A0 :*: A0)))
-- type instance Rep (OrdMap k a) = RepT (OrdMap k) (Rep a)

-- -- $(genRepT [d|
--    instance ReprT (OrdMap k) where
-- 	toRepT = FamT . toFix where
-- 		toFix Tip = HIn (L U)
-- 		toFix (Bin s kx x l r) = HIn (R (K s :*: K kx :*: X x :*: A0 (toFix l) :*: A0 (toFix r)))
-- 	fromRepT (FamT x) = fromFix x where
-- 		fromFix (HIn L{}) = Tip
-- 		fromFix (HIn (R (K s :*: K kx :*: X x :*: A0 l :*: A0 r)))
-- 			= Bin s kx x (fromFix l) (fromFix r) |])

-- | 'TrieKey' instance for 'Ordered' keys, backed by the weight-balanced
-- 'OrdMap'. Methods that receive a key unwrap the 'Ord' constructor. Callbacks
-- that receive keys are adapted with @(f . Ord)@ so they see wrapped keys
-- again. Index results are passed through @onKey Ord@ (presumably re-wrapping
-- the keys inside the 'IndexPos'; confirm in TrieKey).
instance Ord k => TrieKey (Ordered k) (OrdMap k) where
	emptyM = Tip
	nullM Tip = True
	nullM _ = False
	-- Every node caches its subtree's total weight, so the Sized argument is unused.
	sizeM _ = size
	lookupM (Ord k) = lookup k
	-- Index-based queries start with offset 0 at the root.
	lookupIxM s (Ord k) = onKey Ord . lookupIx s 0 k
	assocAtM s i = onKey Ord . assocAt s 0 i
-- 	updateAtM s r f = updateAt s 0 r (\ i -> f i . Ord)
	alterM s f (Ord k) = alter s f k
	alterLookupM s f (Ord k) = alterLookup s f k
	traverseWithKeyM s f = traverseWithKey s (f . Ord)
	foldWithKeyM f = foldrWithKey (f . Ord)
	foldlWithKeyM f = foldlWithKey (f . Ord)
	mapEitherM s1 s2 f = mapEither s1 s2 (f . Ord)
	extractM s f m = extract s (f . Ord) m
-- 	extractMinM _ _ Tip = mzero
-- 	extractMinM s f m = return (deleteFindMin s (f . Ord) m)
-- 	extractMaxM _ _ Tip = mzero
-- 	extractMaxM s f m = return (deleteFindMax s (f . Ord) m)
-- 	alterMinM s f = updateMin s (f . Ord)
-- 	alterMaxM s f = updateMax s (f . Ord)
	splitLookupM s f (Ord k) = splitLookup s f k
	isSubmapM = isSubmap
	-- Both list builders strip the Ord wrapper from every key first.
	fromAscListM s f xs = fromAscList s (f . Ord) [(k, a) | (Ord k, a) <- xs]
	fromDistAscListM s xs = fromDistinctAscList s [(k, a) | (Ord k, a) <- xs]
	-- Empty operands short-circuit. Otherwise run a hedge union whose initial
	-- bounds (const LT / const GT) admit every key.
	unionM s f m1 m2 = case (m1, m2) of
		(Tip, _) -> m2
		(_, Tip) -> m1
		_	 -> hedgeUnionWithKey s (f . Ord) (const LT) (const GT) m1 m2
	isectM s f = isect s (f . Ord)
	-- Difference: an empty left side stays empty and an empty right side removes nothing.
	diffM s f m1 m2 = case (m1, m2) of
		(Tip, _) -> Tip
		(_, Tip) -> m1
		_	 -> hedgeDiffWithKey s (f . Ord) (const LT) (const GT) m1 m2

-- | Find the value stored under a key, if present. Each step compares the
-- search key with the node's key exactly once and descends into one subtree.
lookup :: Ord k => k -> OrdMap k a -> Maybe (a)
lookup needle = descend
  where
    descend Tip = Nothing
    descend (Bin _ k' v l r) = case compare needle k' of
      LT -> descend l
      EQ -> Just v
      GT -> descend r

-- | Locate a key by position. @lookupIx s i k t@ treats @i@ as the absolute
-- weighted index at which subtree @t@ begins. It returns a triple:
--
--   * the nearest entry before @k@ (lower bound),
--   * the entry for @k@ itself, if present,
--   * the nearest entry after @k@ (upper bound).
--
-- Each entry is an 'Asc' carrying its absolute index, key and value. A node's
-- index is the total weight of everything to its left.
lookupIx :: Ord k => Sized a -> Int -> k -> OrdMap k a -> IndexPos k a
-- Force the offset on every call so index arithmetic does not build thunks.
lookupIx _ i _ _ | i `seq` False = undefined
lookupIx _ _ _ Tip = (mzero, mzero, mzero)
lookupIx s i k (Bin sz kx x l r) = case compare k kx of
	-- Key is in the left subtree. This node is a fallback upper bound.
	LT	-> case lookupIx s i k l of
		(lb, ans, ub) -> (lb, ans, ub <|> return (Asc (i + size l) kx x))
	-- Hit: the bounds are the maximum of l and the minimum of r.
	EQ	-> (extractMax (\ k v -> Asc (i + size l - s v) k v) l,
			return (Asc (i + size l) kx x),
		    extractMin (Asc (i + sz - size r)) r)
	-- Key is in the right subtree, which starts at offset i + sz - size r.
	-- NOTE(review): here the node is placed *before* lb in <|>. That is only
	-- the nearest predecessor if the lower-bound component's Alternative is
	-- right-biased (Last-like). With a left-biased one (e.g. Maybe), this node
	-- would win over a closer predecessor found in r. Confirm against IndexPos.
	GT	-> case lookupIx s (i + sz - size r) k r of
		(lb, ans, ub) -> (return (Asc (i + size l) kx x) <|> lb, ans, ub)
	-- Extract the extreme element of a subtree without actually deleting it
	-- (the callback returns Just x, so the value is kept).
	where	extractMin f Tip = mzero
		extractMin f b = return (fst $ deleteFindMin s (\ k x -> (f k x, Just x)) b)
		extractMax f Tip = mzero
		extractMax f b = return (fst $ deleteFindMax s (\ k x -> (f k x, Just x)) b)

-- | Locate the entry that covers a weighted position.
--
-- @assocAt s i0 i t@: @i0@ is the absolute index at which subtree @t@ begins,
-- and @i@ is the position relative to @t@. It returns
-- (nearest entry before, covering entry, nearest entry after). Every 'Asc'
-- carries an absolute index, i.e. one that already includes @i0@.
--
-- Fix: in the covering-node case the successor (minimum of @r@) is now
-- reported at index @i0 + sK@ instead of @sK@. The old value dropped the
-- subtree offset, which made it inconsistent with every other index produced
-- here and in 'lookupIx'.
assocAt :: Sized a -> Int -> Int -> OrdMap k a -> IndexPos k a
-- Force both offsets on every call so the index arithmetic stays strict.
assocAt _ i0 i _ | i0 `seq` i `seq` False = undefined
assocAt _ _ _ Tip = (mzero, mzero, mzero)
assocAt s i0 i (Bin sz k a l r)
  -- Position lies in the left subtree. This node is a fallback upper bound.
  | i < sL, (lb, ans, ub) <- assocAt s i0 i l
  = (lb, ans, ub <|> return (Asc (i0 + sL) k a))
  -- Position lies within this node's own weight [sL, sK).
  | i < sK
  = ( extractMax (\ k' v -> Asc (i0 + sL - s v) k' v) l
    , return (Asc (i0 + sL) k a)
    , extractMin (Asc (i0 + sK)) r )
  -- Position lies in the right subtree, which starts at i0 + sK.
  | (lb, ans, ub) <- assocAt s (i0 + sK) (i - sK) r
  = (return (Asc (i0 + sL) k a) <|> lb, ans, ub)
  where
    sL = size l
    sK = sz - size r
    -- Read the extreme element of a subtree without removing it (Just x keeps it).
    extractMin _ Tip = mzero
    extractMin f b = return (fst $ deleteFindMin s (\ k' x -> (f k' x, Just x)) b)
    extractMax _ Tip = mzero
    extractMax f b = return (fst $ deleteFindMax s (\ k' x -> (f k' x, Just x)) b)

-- | Update or delete the entry at a weighted position.
--
-- @updateAt s i0 round f i t@: @i0@ is the absolute offset of subtree @t@ and
-- @i@ is the position relative to @t@. The callback receives the entry's
-- absolute index, key and value. 'Nothing' removes the entry (via 'glue') and
-- @Just a'@ replaces the value. The 'Round' flag picks between the two clauses
-- below (True / False); presumably it selects rounding direction when @i@
-- falls inside an entry's weight (TODO confirm the Round semantics).
--
-- NOTE(review): this function is currently unused (updateAtM is commented out
-- in the instance).
updateAt :: Sized a -> Int -> Round -> (Int -> k -> a -> Maybe (a)) -> Int -> OrdMap k a -> OrdMap k a
-- Force both offsets on every call.
updateAt _ i0 _ _ i _ | i0 `seq` i `seq` False = undefined
updateAt _ _ _ _ _ Tip = Tip
updateAt s i0 True f i (Bin sz k a l r)
	| i < sL	= balance s k a (updateAt s i0 True f i l) r
	| i < sK	= case f (i0 + sL) k a of
		Nothing	-> glue s l r
		Just a'	-> bin s k a' l r
	| otherwise	= balance s k a l (updateAt s (i0 + sK) True f (i - sK) r)
	where	sL = size l
		sK = sz - size r 
-- NOTE(review): for sL < i < sK (i inside this node's weight), this clause
-- takes the `otherwise` branch and recurses into r with a negative relative
-- index, instead of selecting this node. Looks suspicious if False means
-- "round down"; confirm before re-enabling updateAtM.
updateAt s i0 False f i (Bin sz k a l r)
	| i < maxIxL	= balance s k a (updateAt s i0 False f i l) r
	| i <= sL	= case f (i0 + sL) k a of
		Nothing	-> glue s l r
		Just a' -> bin s k a' l r
	| otherwise	= balance s k a l (updateAt s (i0 + sK) False f (i - sK) r)
	-- maxIxL is the relative start index of the largest entry in l.
	where	sL = size l
		maxIxL = case l of	Tip	-> 0
					_ 	-> fst (deleteFindMax s (\ _ a -> (size l - s a, Just a)) l)
		sK = sz - size r

-- | Insert, update or delete the value at a key, driven by @f@.
-- @f Nothing@ is consulted when the key is absent and @f (Just x)@ when it is
-- present. A 'Nothing' result deletes the entry or leaves the tree without it.
alter :: Ord k => Sized a -> (Maybe (a) -> Maybe (a)) -> k -> OrdMap k a -> OrdMap k a
alter s f k = go
  where
    go Tip = maybe Tip (singleton s k) (f Nothing)
    go (Bin _ kx x l r) = case compare k kx of
      LT -> balance s kx x (go l) r
      GT -> balance s kx x l (go r)
      -- Deleting joins the two children; updating rebalances because the
      -- new value's weight may differ from the old one.
      EQ -> maybe (glue s l r) (\ x' -> balance s k x' l r) (f (Just x))

-- | Like 'alter', but the callback also yields an extra result (the @z@ of the
-- 'CPair'). The rebuilt tree is mapped inside that pair with 'fmap'.
alterLookup :: Ord k => Sized a -> (Maybe a -> CPair z (Maybe a)) -> k -> OrdMap k a -> CPair z (OrdMap k a)
alterLookup s f k = go
  where
    go Tip = maybe Tip (singleton s k) <$> f Nothing
    go (Bin _ kx x l r) = case compare k kx of
      LT -> (\ l' -> balance s kx x l' r) <$> go l
      EQ -> maybe (glue s l r) (\ x' -> balance s k x' l r) <$> f (Just x)
      GT -> balance s kx x l <$> go r

-- | A one-entry tree whose cached weight is the weight of its single value.
singleton :: Sized a -> k -> a -> OrdMap k a
singleton s k a = Bin w k a Tip Tip
  where w = s a

-- | Effectfully map every key/value pair, rebuilding the tree with weights
-- measured by @s@.
--
-- Effects run in ascending key order (left subtree, node, right subtree),
-- matching the order in which 'foldrWithKey' and 'foldlWithKey' visit entries.
-- Previously the node's effect ran before its left subtree (pre-order), so
-- order-sensitive Applicatives (State, Writer, lists) saw entries out of order.
traverseWithKey :: Applicative f => Sized b -> (k -> a -> f (b)) -> OrdMap k a -> f (OrdMap k b)
traverseWithKey s f = go
  where
    go Tip = pure Tip
    -- 'balance' recomputes the cached weight, since the new values may weigh
    -- differently from the old ones.
    go (Bin _ k a l r) =
      (\ l' b r' -> balance s k b l' r') <$> go l <*> f k a <*> go r

-- | Right fold over entries in ascending key order:
-- @foldrWithKey f t z@ combines the smallest key outermost.
foldrWithKey :: (k -> a -> b -> b) -> OrdMap k a -> b -> b
foldrWithKey f = go
  where
    go Tip acc = acc
    go (Bin _ k a l r) acc = go l (f k a (go r acc))

-- | Left fold over entries in ascending key order. The accumulator is threaded
-- lazily: left subtree first, then the node, then the right subtree.
foldlWithKey :: (k -> b -> a -> b) -> OrdMap k a -> b -> b
foldlWithKey f = go
  where
    go Tip acc = acc
    go (Bin _ k a l r) acc = go r (f k (go l acc) a)

-- | Split a tree in two with @f@. For each entry @f k a@ yields a pair of
-- optional values: left results go into the first map (weighted by @s1@) and
-- right results into the second (weighted by @s2@). Missing results drop the
-- key from the corresponding map.
mapEither :: Ord k => Sized b -> Sized c -> EitherMap k (a) (b) (c) ->
	OrdMap k a -> (OrdMap k b, OrdMap k c)
mapEither s1 s2 f = go
  where
    go Tip = (Tip, Tip)
    go (Bin _ k a l r) = case (f k a, go l, go r) of
      ((here1, here2), (l1, l2), (r1, r2)) ->
        (joinMaybe s1 k here1 l1 r1, joinMaybe s2 k here2 l2 r2)

-- | Update or delete the entry with the smallest key. 'Nothing' from the
-- callback removes it; an empty tree is returned unchanged.
updateMin :: Ord k => Sized a -> (k -> a -> Maybe (a)) -> OrdMap k a -> OrdMap k a
updateMin s f = go
  where
    go Tip = Tip
    go (Bin _ k a Tip r) = maybe r (\ a' -> insertMin s k a' r) (f k a)
    go (Bin _ k a l r) = balance s k a (go l) r

-- | Update or delete the entry with the largest key. 'Nothing' from the
-- callback removes it; an empty tree is returned unchanged.
updateMax :: Ord k => Sized a -> (k -> a -> Maybe (a)) -> OrdMap k a -> OrdMap k a
updateMax s f = go
  where
    go Tip = Tip
    go (Bin _ k a l Tip) = maybe l (\ a' -> insertMax s k a' l) (f k a)
    go (Bin _ k a l r) = balance s k a l (go r)

-- | Split a tree at a key into (keys below, result at the key, keys above).
-- When the key is present, @f@ splits its value into optional lower and upper
-- parts, which are re-inserted under the same key on the matching side, plus
-- the middle result.
splitLookup :: Ord k => Sized a -> SplitMap (a) x -> k -> OrdMap k a -> (OrdMap k a, Maybe x, OrdMap k a)
splitLookup s f k = go
  where
    go Tip = (Tip, Nothing, Tip)
    go (Bin _ kx x l r) = case compare k kx of
      LT -> case go l of
        (below, hit, above) -> (below, hit, join s kx x above r)
      GT -> case go r of
        (below, hit, above) -> (join s kx x l below, hit, above)
      EQ -> case f x of
        (lo, hit, hi) ->
          ( maybe l (\ v -> insertMax s kx v l) lo
          , hit
          , maybe r (\ v -> insertMin s kx v r) hi )

-- | @isSubmap leq t1 t2@: every key of @t1@ occurs in @t2@, and for each one
-- @leq@ holds between the @t1@ value and the @t2@ value. The comparison tree is
-- split with a constant weight of 1 because the split pieces are only searched,
-- never returned.
isSubmap :: Ord k => LEq (a) (b) -> LEq (OrdMap k a) (OrdMap k b)
isSubmap _ Tip _ = True
isSubmap _ _ Tip = False
isSubmap leq (Bin _ kx x l r) t = case found of
  Nothing -> False
  Just y  -> leq x y && isSubmap leq l lt && isSubmap leq r gt
  where
    (lt, found, gt) = splitLookup (const 1) (\ v -> (Nothing, Just v, Nothing)) kx t

-- | Build a tree from an ascending list that may contain duplicate keys.
-- Runs of equal keys are collapsed with @f key new accumulated@ (the later
-- key is kept), and the deduplicated list goes to 'fromDistinctAscList'.
fromAscList :: Eq k => Sized a -> (k -> a -> a -> a) -> [(k, a)] -> OrdMap k a
fromAscList s f = fromDistinctAscList s . dedup
  where
    dedup [] = []
    dedup (p : ps) = go p ps
    go acc [] = [acc]
    go acc@(ka, va) (q@(kq, vq) : qs)
      | ka == kq  = go (kq, f kq vq va) qs
      | otherwise = acc : go q qs

-- | Build a tree from a list whose keys are strictly ascending (not checked).
-- The shape is balanced by element count (the list is halved by length). The
-- 'Sized' weights only feed the cached node sizes via 'bin'/'singleton'.
fromDistinctAscList :: Sized a -> [(k, a)] -> OrdMap k a
fromDistinctAscList s xs = build const (length xs) xs
  where
    -- 1) use continuations so that we use heap space instead of stack space.
    -- 2) special case for n==5 to build bushier trees.
    -- build c n xs': consume n elements from xs', build a tree from them, and
    -- pass (tree, remaining list) to the continuation c.
    build c 0 xs'  = c Tip xs'
    build c 5 xs'  = case xs' of
                       ((k1,x1):(k2,x2):(k3,x3):(k4,x4):(k5,x5):xx) 
                            -> c (bin s k4 x4 (bin s k2 x2 (singleton s k1 x1) (singleton s k3 x3)) (singleton s k5 x5)) xx
                       _ -> error "fromDistinctAscList build"
    build c n xs'  = seq nr $ build (buildR nr c) nl xs'
                   where
                     nl = n `div` 2
                     nr = n - nl - 1

    -- After the left subtree: take the root element, then build the right subtree.
    buildR n c l ((k,x):ys) = build (buildB l k x c) n ys
    buildR _ _ _ []         = error "fromDistinctAscList buildR []"
    -- After the right subtree: assemble the node and continue.
    buildB l k x c r zs     = c (bin s k x l r) zs

-- | Hedge union (ported from the old Data.Map). Walks @t1@ and trims @t2@ to
-- the key range strictly between the bounds described by @cmplo@/@cmphi@.
-- For a key present in both trees, @f kx x y@ combines the @t1@ value @x@ with
-- the @t2@ value @y@; a 'Nothing' result drops the key (via 'joinMaybe').
-- The bound comparators have the form @\k -> compare bound k@. 'const LT' and
-- 'const GT' mean "unbounded".
hedgeUnionWithKey :: Ord k
                  => Sized a -> (k -> a -> a -> Maybe (a))
                  -> (k -> Ordering) -> (k -> Ordering)
                  -> OrdMap k a -> OrdMap k a -> OrdMap k a
hedgeUnionWithKey _ _ _     _     t1 Tip
  = t1
-- t1 exhausted: keep only the part of t2 inside (lo, hi).
hedgeUnionWithKey s _ cmplo cmphi Tip (Bin _ kx x l r)
  = join s kx x (filterGt s cmplo l) (filterLt s cmphi r)
hedgeUnionWithKey s f cmplo cmphi (Bin _ kx x l r) t2
  = joinMaybe s kx newx (hedgeUnionWithKey s f cmplo cmpkx l lt) 
                 (hedgeUnionWithKey s f cmpkx cmphi r gt)
  where
    cmpkx k     = compare kx k
    -- Subtree of t2 with keys in (lo, kx).
    lt          = trim cmplo cmpkx t2
    -- kx's entry in t2 (if any) plus the subtree of t2 with keys in (kx, hi).
    (found,gt)  = trimLookupLo kx cmphi t2
    newx        = case found of
                    Nothing -> Just x
                    Just (_,y) -> f kx x y

-- | Keep only keys @k@ with @cmp k == LT@. With @cmp = compare lo@ these are
-- the keys strictly greater than @lo@.
filterGt :: Ord k => Sized a -> (k -> Ordering) -> OrdMap k a -> OrdMap k a
filterGt s cmp = go
  where
    go Tip = Tip
    go (Bin _ kx x l r) = case cmp kx of
      EQ -> r
      GT -> go r
      LT -> join s kx x (go l) r
      
-- | Keep only keys @k@ with @cmp k == GT@. With @cmp = compare hi@ these are
-- the keys strictly less than @hi@.
filterLt :: Ord k => Sized a -> (k -> Ordering) -> OrdMap k a -> OrdMap k a
filterLt s cmp = go
  where
    go Tip = Tip
    go (Bin _ kx x l r) = case cmp kx of
      EQ -> l
      LT -> go l
      GT -> join s kx x l (go r)

-- | Descend to the first subtree whose root lies strictly between the bounds:
-- @cmplo root == LT@ and @cmphi root == GT@. The subtree is returned whole;
-- its descendants may still fall outside the range.
trim :: (k -> Ordering) -> (k -> Ordering) -> OrdMap k a -> OrdMap k a
trim cmplo cmphi = go
  where
    go Tip = Tip
    go t@(Bin _ kx _ l r)
      | cmplo kx /= LT = go r
      | cmphi kx == GT = t
      | otherwise      = go l
              
-- | Like 'trim' with lower bound @lo@, but also returns @lo@'s entry if
-- present. When a root inside (lo, hi) is reached, @lo@ is looked up in that
-- subtree. When @lo@ itself is hit, the right child is trimmed against the
-- same bounds.
trimLookupLo :: Ord k => k -> (k -> Ordering) -> OrdMap k a -> (Maybe (k,a), OrdMap k a)
trimLookupLo lo cmphi = go
  where
    go Tip = (Nothing, Tip)
    go t@(Bin _ kx x l r) = case compare lo kx of
      LT | cmphi kx == GT -> (fmap (\ v -> (lo, v)) (lookup lo t), t)
         | otherwise      -> go l
      EQ -> (Just (kx, x), trim (compare lo) cmphi r)
      GT -> go r

-- | Intersection driven by the structure of the second tree. Each key of
-- @t2@ is looked up in @t1@ by splitting, and @f key a b@ decides the
-- combined value ('Nothing' drops the key). The split uses a constant weight
-- of 1 because its pieces are only recursed into, never returned.
isect :: Ord k => Sized c -> IsectFunc k (a) (b) (c) -> OrdMap k a -> OrdMap k b -> OrdMap k c
isect _ _ Tip _ = Tip
isect _ _ _ Tip = Tip
isect s f t1 (Bin _ k2 x2 l2 r2) =
  let (below, found, above) = splitLookup (const 1) (\ v -> (Nothing, Just v, Nothing)) k2 t1
      combined = found >>= \ v -> f k2 v x2
   in joinMaybe s k2 combined (isect s f below l2) (isect s f above r2)


-- | Hedge difference (ported from the old Data.Map). Removes from @t@ the keys
-- that occur in the second tree. For a shared key, @f ky y x@ (the @t@ value
-- first, the second tree's value second) decides: 'Nothing' removes the key,
-- @Just z@ keeps it with value @z@. The bounds follow the same convention as in
-- 'hedgeUnionWithKey'.
hedgeDiffWithKey :: Ord k
                 => Sized a -> (k -> a -> b -> Maybe (a))
                 -> (k -> Ordering) -> (k -> Ordering)
                 -> OrdMap k a -> OrdMap k b -> OrdMap k a
hedgeDiffWithKey _ _ _     _     Tip _
  = Tip
-- Nothing left to subtract: keep the part of t inside (lo, hi).
hedgeDiffWithKey s _ cmplo cmphi (Bin _ kx x l r) Tip
  = join s kx x (filterGt s cmplo l) (filterLt s cmphi r)
hedgeDiffWithKey s f cmplo cmphi t (Bin _ kx x l r) 
  = case found of
      Nothing -> merge s tl tr
      Just (ky,y) -> 
          case f ky y x of
            Nothing -> merge s tl tr
            Just z  -> join s ky z tl tr
  where
    cmpkx k     = compare kx k   
    -- Part of t with keys in (lo, kx), and kx's entry plus the part in (kx, hi).
    lt          = trim cmplo cmpkx t
    (found,gt)  = trimLookupLo kx cmphi t
    tl          = hedgeDiffWithKey s f cmplo cmpkx lt l
    tr          = hedgeDiffWithKey s f cmpkx cmphi gt r

-- | Join two trees around an optional middle entry. With no entry the trees
-- are concatenated directly ('merge'); otherwise the entry becomes the
-- separator ('join').
joinMaybe :: Ord k => Sized a -> k -> Maybe (a) -> OrdMap k a -> OrdMap k a -> OrdMap k a
joinMaybe s kx mx = case mx of
  Nothing -> merge s
  Just x  -> join s kx x

-- | Combine @l@, the entry @(kx, x)@ and @r@, assuming every key in @l@ is
-- below @kx@ and every key in @r@ is above it. If one side is more than
-- @delta@ times heavier, the function descends into that side's inner spine
-- and rebalances on the way back up.
join :: Ord k => Sized a -> k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
join s kx x l r = case (l, r) of
  (Tip, _) -> insertMin s kx x r
  (_, Tip) -> insertMax s kx x l
  (Bin wl ky y ly ry, Bin wr kz z lz rz)
    | delta * wl <= wr -> balance s kz z (join s kx x l lz) rz
    | delta * wr <= wl -> balance s ky y ly (join s kx x ry r)
    | otherwise        -> bin s kx x l r


-- | Insert an entry whose key is assumed to be larger ('insertMax') or
-- smaller ('insertMin') than every key already present. They walk the
-- right (resp. left) spine without comparing keys, which avoids potentially
-- expensive comparisons; the caller must guarantee the ordering.
insertMax,insertMin :: Sized a -> k -> a -> OrdMap k a -> OrdMap k a
insertMax s kx x = go
  where
    go Tip = singleton s kx x
    go (Bin _ ky y l r) = balance s ky y l (go r)

insertMin s kx x = go
  where
    go Tip = singleton s kx x
    go (Bin _ ky y l r) = balance s ky y (go l) r
             
{--------------------------------------------------------------------
  [merge l r]: concatenates two trees with no separating entry.
  Assumes every key in [l] is smaller than every key in [r]; the
  trees may be unbalanced with respect to each other.
--------------------------------------------------------------------}
merge :: Sized a -> OrdMap k a -> OrdMap k a -> OrdMap k a
merge s l r = case (l, r) of
  (Tip, _) -> r
  (_, Tip) -> l
  (Bin wl kx x lx rx, Bin wr ky y ly ry)
    | delta * wl <= wr -> balance s ky y (merge s l ly) ry
    | delta * wr <= wl -> balance s kx x lx (merge s rx r)
    | otherwise        -> glue s l r

{--------------------------------------------------------------------
  [glue l r]: concatenates two trees with no separating entry.
  Assumes that [l] and [r] are already balanced with respect to each
  other and that every key in [l] is below every key in [r]. The
  extreme entry of the heavier side becomes the new root.
--------------------------------------------------------------------}
glue :: Sized a -> OrdMap k a -> OrdMap k a -> OrdMap k a
glue _ Tip r = r
glue _ l Tip = l
glue s l r
  | size l > size r = case deleteFindMax s promote l of (mkRoot, l') -> mkRoot l' r
  | otherwise       = case deleteFindMin s promote r of (mkRoot, r') -> mkRoot l r'
  where
    -- Remove the extreme entry (Nothing) and return a builder that
    -- re-inserts it as the root between the two remaining trees.
    promote k a = (balance s k a, Nothing)

-- | Offer every entry, in ascending key order, to @f@ and combine the
-- alternatives with '<|>'. For each candidate, @f@ yields an extra result
-- together with an optional replacement value. 'Nothing' deletes the entry
-- (via 'glue') and @Just x'@ replaces its value. The rebuilt tree is paired
-- with that extra result.
--
-- Fix: the empty tree now yields 'empty'. Previously there was no 'Tip' case,
-- so recursing into any leaf's empty child (or calling this on an empty map)
-- failed with a non-exhaustive pattern error.
extract :: Alternative t => Sized a -> (k -> a -> t (CPair z (Maybe a))) -> OrdMap k a -> t (CPair z (OrdMap k a))
extract s f = go
  where
    go Tip = empty
    go (Bin _ k x l r) =
      fmap (\ l' -> balance s k x l' r) <$> go l <|>
      fmap (maybe (glue s l r) (\ x' -> balance s k x' l r)) <$> f k x <|>
      fmap (balance s k x l) <$> go r

-- | Visit the entry with the smallest key. @f@ returns a result plus an
-- optional replacement value ('Nothing' removes the entry). The result and the
-- rebuilt tree are returned lazily. Calling this on an empty tree yields an
-- 'error' thunk as the result.
deleteFindMin :: Sized a -> (k -> a -> (x, Maybe a)) -> OrdMap k a -> (x, OrdMap k a)
deleteFindMin s f = go
  where
    go (Bin _ k x Tip r) =
      let (res, mx) = f k x
       in (res, maybe r (\ y -> bin s k y Tip r) mx)
    go (Bin _ k x l r) =
      let (res, l') = go l
       in (res, balance s k x l' r)
    go Tip = (error "Map.deleteFindMin: can not return the minimal element of an empty map", Tip)

-- | Visit the entry with the largest key. @f@ returns a result plus an
-- optional replacement value ('Nothing' removes the entry). The result and the
-- rebuilt tree are returned lazily. Calling this on an empty tree yields an
-- 'error' thunk as the result.
deleteFindMax :: Sized a -> (k -> a -> (x, Maybe a)) -> OrdMap k a -> (x, OrdMap k a)
deleteFindMax s f = go
  where
    go (Bin _ k x l Tip) =
      let (res, mx) = f k x
       in (res, maybe l (\ y -> bin s k y l Tip) mx)
    go (Bin _ k x l r) =
      let (res, r') = go r
       in (res, balance s k x l r')
    go Tip = (error "Map.deleteFindMax: can not return the maximal element of an empty map", Tip)

-- | Balance parameters, used by 'balance', 'join' and 'merge'.
--
--   * @delta@: a node is rebalanced when one subtree weighs at least @delta@
--     times the other.
--   * @ratio@: 'rotateL'/'rotateR' use a double rotation instead of a single
--     one when the inner grandchild weighs at least @ratio@ times the outer one.
--
-- NOTE(review): (5, 2) matches older Data.Map releases. Hirai & Yamamoto
-- ("Balancing weight-balanced trees", JFP 2011) show that (3, 2) is the only
-- integer pair preserving the balance invariant under insert/delete. Sizes here
-- are 'Sized' weights rather than element counts, so that proof does not carry
-- over directly either way. Confirm before changing.
delta,ratio :: Int
delta = 5
ratio = 2

-- | Cached total weight of a tree (0 for the empty tree).
size :: OrdMap k a -> Int
size t = case t of
  Bin w _ _ _ _ -> w
  Tip           -> 0

-- | Build a node from an entry and two subtrees. A rotation is applied when
-- one side weighs at least @delta@ times the other; otherwise a plain node is
-- built with weight @size l + size r + s x@. Trees whose children weigh at most
-- 1 in total are never rotated.
balance :: Sized a -> k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
balance s k x l r
  | wl + wr <= 1     = node
  | wr >= delta * wl = rotateL s k x l r
  | wl >= delta * wr = rotateR s k x l r
  | otherwise        = node
  where
    wl = size l
    wr = size r
    node = Bin (wl + wr + s x) k x l r

-- | Rotate a right-heavy node left. A single rotation is used when the right
-- child's inner (left) subtree weighs less than @ratio@ times its outer
-- (right) subtree; otherwise a double rotation is used.
rotateL :: Sized a -> k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
rotateL s k x l r = case r of
  Bin _ _ _ ly ry
    | size ly < ratio * size ry -> singleL s k x l r
    | otherwise                 -> doubleL s k x l r
  Tip -> error "rotateL Tip"

-- | Rotate a left-heavy node right. A single rotation is used when the left
-- child's inner (right) subtree weighs less than @ratio@ times its outer
-- (left) subtree; otherwise a double rotation is used.
rotateR :: Sized a -> k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
rotateR s k x l r = case l of
  Bin _ _ _ ly ry
    | size ry < ratio * size ly -> singleR s k x l r
    | otherwise                 -> doubleR s k x l r
  Tip -> error "rotateR Tip"

-- basic rotations
-- | Single left/right rotation: the heavy child becomes the new root. If the
-- child is empty, the node is rebuilt unchanged.
singleL, singleR :: Sized a -> k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
singleL s k1 x1 t1 t = case t of
  Bin _ k2 x2 t2 t3 -> bin s k2 x2 (bin s k1 x1 t1 t2) t3
  Tip               -> bin s k1 x1 t1 Tip
singleR s k1 x1 t t3 = case t of
  Bin _ k2 x2 t1 t2 -> bin s k2 x2 t1 (bin s k1 x1 t2 t3)
  Tip               -> bin s k1 x1 Tip t3

-- | Double left/right rotation: the heavy child's inner grandchild becomes the
-- new root. When that grandchild does not exist, fall back to the single
-- rotation.
doubleL, doubleR :: Sized a -> k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
doubleL s k1 x1 t1 t = case t of
  Bin _ k2 x2 (Bin _ k3 x3 t2 t3) t4 ->
    bin s k3 x3 (bin s k1 x1 t1 t2) (bin s k2 x2 t3 t4)
  _ -> singleL s k1 x1 t1 t
doubleR s k1 x1 t t4 = case t of
  Bin _ k2 x2 t1 (Bin _ k3 x3 t2 t3) ->
    bin s k3 x3 (bin s k2 x2 t1 t2) (bin s k1 x1 t3 t4)
  _ -> singleR s k1 x1 t t4

-- | Build a node without rebalancing. Its cached weight is the sum of both
-- subtrees' weights and the weight of its own value.
bin :: Sized a -> k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
bin s k x l r = Bin w k x l r
  where w = size l + size r + s x