{-# LANGUAGE BangPatterns, UnboxedTuples, TypeFamilies, PatternGuards, MagicHash, CPP, TupleSections #-}

module Data.TrieMap.OrdMap () where

import Data.TrieMap.TrieKey
import Data.TrieMap.Sized
import Data.TrieMap.Modifiers

import Control.Applicative
import Control.Monad hiding (join, fmap)

import Prelude hiding (lookup, foldr, foldl, fmap)

import GHC.Exts

#define DELTA 5#
#define RATIO 2#

type OrdMap k = TrieMap (Ordered k)

data Path k a =
	Root
	| LeftBin k a !(Path k a) !(OrdMap k a)
	| RightBin k a !(OrdMap k a) !(Path k a)

singletonMaybe :: Sized a => k -> Maybe a -> OrdMap k a
singletonMaybe k = maybe Tip (singleton k)

-- | @'TrieMap' ('Ordered' k) a@ is based on "Data.Map".
instance Ord k => TrieKey (Ordered k) where
	Ord k1 =? Ord k2	= k1 == k2
	Ord k1 `cmp` Ord k2	= k1 `compare` k2
  
	data TrieMap (Ordered k) a = Tip 
              | Bin Int# k a !(OrdMap k a) !(OrdMap k a)
        data Hole (Ordered k) a = 
        	Empty k !(Path k a)
        	| Full k !(Path k a) !(OrdMap k a) !(OrdMap k a)
	emptyM = Tip
	singletonM (Ord k) = singleton k
	lookupM (Ord k) = lookup k
	getSimpleM Tip			= Null
	getSimpleM (Bin _ _ a Tip Tip)	= Singleton a
	getSimpleM _			= NonSimple
	sizeM = size#
	traverseM = traverse
	foldrM = foldr
	foldlM = foldl
	fmapM = fmap
	mapMaybeM = mapMaybe
	mapEitherM = mapEither
	isSubmapM = isSubmap
	fromAscListM  f xs = fromAscList f [(k, a) | (Ord k, a) <- xs]
	fromDistAscListM  xs = fromDistinctAscList  [(k, a) | (Ord k, a) <- xs]
	unionM _ Tip m2 = m2
	unionM _ m1 Tip = m1
	unionM f m1 m2 = hedgeUnion f (const LT) (const GT) m1 m2
	isectM = isect
	diffM _ Tip _ = Tip
	diffM _ m1 Tip = m1
	diffM f m1 m2 = hedgeDiff f (const LT) (const GT) m1 m2
	
	singleHoleM (Ord k) = Empty k Root
	beforeM a (Empty k path) = before (singletonMaybe  k a) path
	beforeM a (Full k path l _) = before t path
		where	t = case a of
				Nothing	-> l
				Just a	-> insertMax k a l
	afterM  a (Empty k path) = after (singletonMaybe  k a) path
	afterM  a (Full k path _ r) = after t path
		where	t = case a of
				Nothing	-> r
				Just a	-> insertMin  k a r
	searchM (Ord k) = search k Root
	indexM i# = indexT Root i# where
		indexT path i# (Bin _ kx x l r) 
		  | i# <# sl#	= indexT (LeftBin kx x path r) i# l
		  | i# <# sx#	= (# i# -# sl#, x, Full kx path l r #)
		  | otherwise	= indexT (RightBin kx x l path) (i# -# sx#) r
			where	!sl# = size# l
				!sx# = getSize# x +# sl#
		indexT _ _ _ = indexFail ()
	extractHoleM = extractHole Root where
		extractHole path (Bin _ kx x l r) =
			extractHole (LeftBin kx x path r) l `mplus`
			return (x, Full kx path l r) `mplus`
			extractHole (RightBin kx x l path) r
		extractHole _ _ = mzero
	assignM x (Empty k path) = rebuild (maybe Tip (singleton k) x) path
	assignM x (Full k path l r) = rebuild (joinMaybe k x l r) path
	
	unifyM (Ord k1) a1 (Ord k2) a2 = case compare k1 k2 of
		EQ	-> Left $ Empty k1 Root
		LT	-> Right $ bin k1 a1 Tip (singleton k2 a2)
		GT	-> Right $ bin k1 a1 (singleton k2 a2) Tip

rebuild :: Sized a => OrdMap k a -> Path k a -> OrdMap k a
rebuild t Root = t
rebuild t (LeftBin kx x path r) = rebuild (balance kx x t r) path
rebuild t (RightBin kx x l path) = rebuild (balance kx x l t) path

lookup :: Ord k => k -> OrdMap k a -> Maybe a
lookup k (Bin _ k' v l r) = case compare k k' of
	LT	-> lookup k l
	EQ	-> Just v
	GT	-> lookup k r
lookup _ _ = Nothing

singleton :: Sized a => k -> a -> OrdMap k a
singleton k a = Bin (getSize# a) k a Tip Tip

traverse :: (Applicative f, Sized b) => (a -> f b) -> OrdMap k a -> f (OrdMap k b)
traverse _ Tip = pure Tip
traverse f (Bin _ k a l r) = balance k <$> f a <*> traverse f l <*> traverse f r

foldr :: (a -> b -> b) -> OrdMap k a -> b -> b
foldr _ Tip = id
foldr f (Bin _ _ a l r) = foldr f l . f a . foldr f r

foldl :: (b -> a -> b) -> OrdMap k a -> b -> b
foldl _ Tip = id
foldl f (Bin _ _ a l r) = foldl f r . flip f a . foldl f l

fmap :: (Ord k, Sized b) => (a -> b) -> OrdMap k a -> OrdMap k b
fmap f (Bin _ k a l r) = join k (f a) (fmap f l) (fmap f r)
fmap _ _ = Tip

mapMaybe :: (Ord k, Sized b) => (a -> Maybe b) -> OrdMap k a -> OrdMap k b
mapMaybe f (Bin _ k a l r) = joinMaybe  k (f a) (mapMaybe f l) (mapMaybe f r)
mapMaybe _ _ = Tip

mapEither :: (Ord k, Sized b, Sized c) => (a -> (# Maybe b, Maybe c #)) ->
	OrdMap k a -> (# OrdMap k b, OrdMap k c #)
mapEither f (Bin _ k a l r) = (# joinMaybe k aL lL rL, joinMaybe k aR lR rR #)
  where !(# aL, aR #) = f a; !(# lL, lR #) = mapEither f l; !(# rL, rR #) = mapEither f r
mapEither _ _ = (# Tip, Tip #)

splitLookup :: (Ord k, Sized a) => k -> OrdMap k a -> (# OrdMap k a, Maybe a, OrdMap k a #)
splitLookup k m = case m of
	Tip	-> (# Tip, Nothing, Tip #)
	Bin _ kx x l r -> case compare k kx of
		LT	-> let !(# lL, ans, lR #) = splitLookup k l in (# lL, ans, join kx x lR r #)
		EQ	-> (# l, Just x, r #)
		GT	-> let !(# rL, ans, rR #) = splitLookup k r in (# join kx x l rL, ans, rR #)

isSubmap :: (Ord k, Sized a, Sized b) => LEq a b -> LEq (OrdMap k a) (OrdMap k b)
isSubmap _ Tip _ = True
isSubmap _ _ Tip = False
isSubmap (<=) (Bin _ kx x l r) t = case found of
	  Nothing	-> False
	  Just y	-> x <= y && isSubmap (<=) l lt && isSubmap (<=) r gt
  where !(# lt, found, gt #) = splitLookup kx t

fromAscList :: (Eq k, Sized a) => (a -> a -> a) -> [(k, a)] -> OrdMap k a
fromAscList f xs = fromDistinctAscList (combineEq xs) where
	combineEq (x:xs) = combineEq' x xs
	combineEq [] = []
	
	combineEq' z [] = [z]
	combineEq' (kz, zz) (x@(kx, xx):xs)
		| kz == kx	= combineEq' (kx, f xx zz) xs
		| otherwise	= (kz,zz):combineEq' x xs

fromDistinctAscList :: Sized a => [(k, a)] -> OrdMap k a
fromDistinctAscList xs = build const (length xs) xs
  where
    -- 1) use continutations so that we use heap space instead of stack space.
    -- 2) special case for n==5 to build bushier trees. 
    build c 0 xs'  = c Tip xs'
    build c 5 xs'  = case xs' of
                      ((k1,x1):(k2,x2):(k3,x3):(k4,x4):(k5,x5):xx) 
                            -> c (bin k4 x4 (bin k2 x2 (singleton k1 x1) (singleton k3 x3)) (singleton k5 x5)) xx
                      _ -> error "fromDistinctAscList build"
    build c n xs'  = seq nr $ build (buildR nr c) nl xs'
                   where
                     nl = n `div` 2
                     nr = n - nl - 1

    buildR n c l ((k,x):ys) = build (buildB l k x c) n ys
    buildR _ _ _ []         = error "fromDistinctAscList buildR []"
    buildB l k x c r zs     = c (bin k x l r) zs

hedgeUnion :: (Ord k, Sized a)
                  => (a -> a -> Maybe a)
                  -> (k -> Ordering) -> (k -> Ordering)
                  -> OrdMap k a -> OrdMap k a -> OrdMap k a
hedgeUnion _ _     _     t1 Tip
  = t1
hedgeUnion _ cmplo cmphi Tip (Bin _ kx x l r)
  = join kx x (filterGt  cmplo l) (filterLt  cmphi r)
hedgeUnion f cmplo cmphi (Bin _ kx x l r) t2
  = joinMaybe  kx newx (hedgeUnion  f cmplo cmpkx l lt) 
                (hedgeUnion  f cmpkx cmphi r gt)
  where
    cmpkx k     = compare kx k
    lt          = trim cmplo cmpkx t2
    (found,gt)  = trimLookupLo kx cmphi t2
    newx        = case found of
                    Nothing -> Just x
                    Just (_,y) -> f x y

filterGt :: (Ord k, Sized a) => (k -> Ordering) -> OrdMap k a -> OrdMap k a
filterGt _   Tip = Tip
filterGt cmp (Bin _ kx x l r)
  = case cmp kx of
      LT -> join kx x (filterGt  cmp l) r
      GT -> filterGt  cmp r
      EQ -> r
      
filterLt :: (Ord k, Sized a) => (k -> Ordering) -> OrdMap k a -> OrdMap k a
filterLt _   Tip = Tip
filterLt cmp (Bin _ kx x l r)
  = case cmp kx of
      LT -> filterLt cmp l
      GT -> join kx x l (filterLt  cmp r)
      EQ -> l

trim :: (k -> Ordering) -> (k -> Ordering) -> OrdMap k a -> OrdMap k a
trim _     _     Tip = Tip
trim cmplo cmphi t@(Bin _ kx _ l r)
  = case cmplo kx of
      LT -> case cmphi kx of
              GT -> t
              _  -> trim cmplo cmphi l
      _  -> trim cmplo cmphi r
              
trimLookupLo :: Ord k => k -> (k -> Ordering) -> OrdMap k a -> (Maybe (k,a), OrdMap k a)
trimLookupLo _  _     Tip = (Nothing,Tip)
trimLookupLo lo cmphi t@(Bin _ kx x l r)
  = case compare lo kx of
      LT -> case cmphi kx of
              GT -> ((lo,) <$> lookup lo t, t)
              _  -> trimLookupLo lo cmphi l
      GT -> trimLookupLo lo cmphi r
      EQ -> (Just (kx,x),trim (compare lo) cmphi r)

isect :: (Ord k, Sized a, Sized b, Sized c) => (a -> b -> Maybe c) -> OrdMap k a -> OrdMap k b -> OrdMap k c
isect f t1@Bin{} (Bin _ k2 x2 l2 r2) 
  = joinMaybe k2 (found >>= \ x1' -> f x1' x2) tl tr
  where	!(# found, hole #) = search k2 Root t1
	tl = isect f (beforeM Nothing hole) l2
	tr = isect f (afterM Nothing hole) r2
isect _ _ _ = Tip

hedgeDiff :: (Ord k, Sized a)
                 => (a -> b -> Maybe a)
                 -> (k -> Ordering) -> (k -> Ordering)
                 -> OrdMap k a -> OrdMap k b -> OrdMap k a
hedgeDiff _ _     _     Tip _
  = Tip
hedgeDiff _ cmplo cmphi (Bin _ kx x l r) Tip
  = join kx x (filterGt  cmplo l) (filterLt  cmphi r)
hedgeDiff  f cmplo cmphi t (Bin _ kx x l r) 
  = case found of
      Nothing -> merge  tl tr
      Just (ky,y) -> 
          case f y x of
            Nothing -> merge tl tr
            Just z  -> join ky z tl tr
  where
    cmpkx k     = compare kx k   
    lt          = trim cmplo cmpkx t
    (found,gt)  = trimLookupLo kx cmphi t
    tl          = hedgeDiff f cmplo cmpkx lt l
    tr          = hedgeDiff f cmpkx cmphi gt r

joinMaybe :: (Ord k, Sized a) => k -> Maybe a -> OrdMap k a -> OrdMap k a -> OrdMap k a
joinMaybe kx = maybe merge (join kx)

join :: Sized a => k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
join kx x Tip r  = insertMin  kx x r
join kx x l Tip  = insertMax  kx x l
join kx x l@(Bin sL# ky y ly ry) r@(Bin sR# kz z lz rz)
  | DELTA *# sL# <=# sR# = balance kz z (join kx x l lz) rz
  | DELTA *# sR# <=# sL# = balance ky y ly (join kx x ry r)
  | otherwise             = bin kx x l r

-- insertMin and insertMax don't perform potentially expensive comparisons.
insertMax,insertMin :: Sized a => k -> a -> OrdMap k a -> OrdMap k a
insertMax kx x t
  = case t of
      Tip -> singleton kx x
      Bin _ ky y l r
          -> balance ky y l (insertMax kx x r)
             
insertMin kx x t
  = case t of
      Tip -> singleton kx x
      Bin _ ky y l r
          -> balance ky y (insertMin kx x l) r
             
{--------------------------------------------------------------------
  [merge l r]: merges two trees.
--------------------------------------------------------------------}
merge :: Sized a => OrdMap k a -> OrdMap k a -> OrdMap k a
merge Tip r   = r
merge l Tip   = l
merge l@(Bin sL# kx x lx rx) r@(Bin sR# ky y ly ry)
  | DELTA *# sL# <=# sR# = balance ky y (merge l ly) ry
  | DELTA *# sR# <=# sL# = balance kx x lx (merge rx r)
  | otherwise		  = glue l r

{--------------------------------------------------------------------
  [glue l r]: glues two trees together.
  Assumes that [l] and [r] are already balanced with respect to each other.
--------------------------------------------------------------------}
glue :: Sized a => OrdMap k a -> OrdMap k a -> OrdMap k a
glue Tip r = r
glue l Tip = l
glue l r
  | size# l ># size# r	= let !(# f, l' #) = deleteFindMax (\ k a -> (# balance k a, Nothing #)) l in f l' r
  | otherwise		= let !(# f, r' #) = deleteFindMin (\ k a -> (# balance k a, Nothing #)) r in f l r'

deleteFindMin :: Sized a => (k -> a -> (# x, Maybe a #)) -> OrdMap k a -> (# x, OrdMap k a #)
deleteFindMin f t 
  = case t of
      Bin _ k x Tip r	-> onSnd (maybe r (\ y' -> bin k y' Tip r)) (f k) x
      Bin _ k x l r	-> onSnd (\ l' -> balance k x l' r) (deleteFindMin f) l
      _			-> (# error "Map.deleteFindMin: can not return the minimal element of an empty fmap", Tip #)

deleteFindMax :: Sized a => (k -> a -> (# x, Maybe a #)) -> OrdMap k a -> (# x, OrdMap k a #)
deleteFindMax f t
  = case t of
      Bin _ k x l Tip -> onSnd (maybe l (\ y -> bin k y l Tip)) (f k) x
      Bin _ k x l r   -> onSnd (balance k x l) (deleteFindMax f) r
      Tip             -> (# error "Map.deleteFindMax: can not return the maximal element of an empty fmap", Tip #)

size# :: OrdMap k a -> Int#
size# Tip = 0#
size# (Bin sz _ _ _ _) = sz

balance :: Sized a => k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
balance k x l r
  | sR# >=# (DELTA *# sL#)	= rotateL  k x l r
  | sL# >=# (DELTA *# sR#)	= rotateR  k x l r
  | otherwise			= Bin sX# k x l r
  where
    !sL# = size# l
    !sR# = size# r
    !sX# = sL# +# sR# +# getSize# x

-- rotate
rotateL :: Sized a => k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
rotateL k x l r@(Bin _ _ _ ly ry)
  | sL# <# (RATIO *# sR#)	= singleL k x l r
  | otherwise			= doubleL k x l r
  where	!sL# = size# ly
  	!sR# = size# ry
rotateL _ _ _ Tip = error "rotateL Tip"

rotateR :: Sized a => k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
rotateR k x l@(Bin _ _ _ ly ry) r
  | sR# <# (RATIO *# sL#)	= singleR k x l r
  | otherwise			= doubleR k x l r
  where	!sL# = size# ly
  	!sR# = size# ry
rotateR _ _ _ _ = error "rotateR Tip"

-- basic rotations
singleL, singleR :: Sized a => k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
singleL k1 x1 t1 (Bin _ k2 x2 t2 t3)  = bin k2 x2 (bin k1 x1 t1 t2) t3
singleL k1 x1 t1 Tip = bin k1 x1 t1 Tip
singleR  k1 x1 (Bin _ k2 x2 t1 t2) t3  = bin k2 x2 t1 (bin k1 x1 t2 t3)
singleR  k1 x1 Tip t2 = bin k1 x1 Tip t2

doubleL, doubleR :: Sized a => k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
doubleL  k1 x1 t1 (Bin _ k2 x2 (Bin _ k3 x3 t2 t3) t4) = bin k3 x3 (bin k1 x1 t1 t2) (bin k2 x2 t3 t4)
doubleL  k1 x1 t1 t2 = singleL k1 x1 t1 t2
doubleR  k1 x1 (Bin _ k2 x2 t1 (Bin _ k3 x3 t2 t3)) t4 = bin k3 x3 (bin k2 x2 t1 t2) (bin k1 x1 t3 t4)
doubleR  k1 x1 t1 t2 = singleR  k1 x1 t1 t2

bin :: Sized a => k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
bin k x l r
  = Bin (size# l +# size# r +# getSize# x) k x l r

before :: Sized a => OrdMap k a -> Path k a -> OrdMap k a
before t (LeftBin _ _ path _) = before t path
before t (RightBin k a l path) = before (join k a l t) path
before t _ = t

after :: Sized a => OrdMap k a -> Path k a -> OrdMap k a
after t (LeftBin k a path r) = after (join k a t r) path
after t (RightBin _ _ _ path) = after t path
after t _ = t

search :: Ord k => k -> Path k a -> OrdMap k a -> (# Maybe a, Hole (Ordered k) a #)
search k path Tip = (# Nothing, Empty k path #)
search k path (Bin _ kx x l r) = case compare k kx of
	LT	-> search k (LeftBin kx x path r) l
	EQ	-> (# Just x, Full k path l r #)
	GT	-> search k (RightBin kx x l path) r