{-# LANGUAGE BangPatterns, Rank2Types, CPP, MagicHash, PatternGuards, MultiParamTypeClasses, TypeFamilies #-}

module Data.TrieMap.IntMap () where

import Data.TrieMap.TrieKey
import Data.TrieMap.Applicative
import Data.TrieMap.Sized

import Control.Applicative (Applicative(..), (<$>))
import Control.Arrow

import Data.Bits
import Data.Maybe
import Data.Monoid
import Data.Word

#if __GLASGOW_HASKELL__ >= 503
import GHC.Exts ( Word(..), Int(..), shiftRL# )
#elif __GLASGOW_HASKELL__
import Word
import GlaExts ( Word(..), Int(..), shiftRL# )
#else
import Data.Word
#endif

import Prelude hiding (lookup, null, foldl, foldr)

-- | Unsigned machine word; keys and masks are converted to this type
-- (see 'natFromInt') before bit-level comparisons.
type Nat = Word

-- | A size-annotated Patricia tree keyed by 'Int'. Values have type @a ix@.
--
-- Every non-empty node stores a 'Size': for a 'Tip' it is whatever the
-- caller-supplied size function returned for the value (see 'singleton');
-- for a 'Bin' it is the sum of the sizes of its two children (see 'bin').
data IntMap a ix = Nil
              -- Tip: cached size, key, value.
              | Tip {-# UNPACK #-} !Size {-# UNPACK #-} !Key (a ix)
              -- Bin: cached size, prefix, branching mask, left subtree, right subtree.
              | Bin {-# UNPACK #-} !Size {-# UNPACK #-} !Prefix {-# UNPACK #-} !Mask !(IntMap a ix) !(IntMap a ix) 
		deriving (Show)
-- Register this structure as the trie representation for 'Int' keys.
type instance TrieMap Int = IntMap

-- Readability aliases; all are plain 'Int's.
-- Prefix: the key bits shared by everything below a 'Bin' (see 'mask').
type Prefix = Int
-- Mask: the single branching bit tested at a 'Bin' (see 'zero').
type Mask   = Int
-- Key: the map's key type.
type Key    = Int
-- Size: the cached size stored in 'Tip' and 'Bin' nodes.
type Size   = Int

-- | 'TrieKey' instance for 'Int' keys. Each class method delegates to a
-- function defined in this module; the class itself is declared in
-- "Data.TrieMap.TrieKey".
instance TrieKey Int IntMap where
	emptyM = Nil
	nullM = null
	-- The first argument is ignored: sizes are cached in the nodes.
	sizeM _ = size
	-- Point lookups operate on the unsigned ('Nat') view of the key.
	lookupM = lookup . natFromInt
	lookupIxM _ = lookupIx . natFromInt
	-- 'fromJust' makes this partial: it fails whenever 'assocAt' yields 'Nothing'.
	-- NOTE(review): '.:' comes from another module; presumably two-argument composition - confirm.
	assocAtM _ = fromJust .: assocAt
	updateAtM = updateAt
	alterM = alter
	traverseWithKeyM = traverseWithKey
	-- 'foldr' and 'foldl' here are this module's versions (the Prelude ones are hidden).
	foldWithKeyM = foldr
	foldlWithKeyM = foldl
	mapEitherM = mapEither
	splitLookupM = splitLookup
	unionM = unionWithKey
	isectM = intersectionWithKey
	diffM = differenceWithKey
	-- The min/max views are wrapped in the 'First' / 'Last' monoids; the first argument is ignored.
	extractMinM _ = First . minViewWithKey
	extractMaxM _ = Last . maxViewWithKey
	alterMinM = updateMinWithKey
	alterMaxM = updateMaxWithKey
	isSubmapM = isSubmapOfBy

-- | Convert a signed 'Int' to the unsigned 'Nat' representation
-- (via 'fromIntegral').
natFromInt :: Int -> Nat
natFromInt i = fromIntegral i

-- | Convert an unsigned 'Nat' back to a signed 'Int' (via 'fromIntegral').
intFromNat :: Nat -> Int
intFromNat w = fromIntegral w

-- | Right shift of an unsigned word by the given number of bits.
-- Under GHC this is written against the unboxed primitive 'shiftRL#';
-- other compilers fall back to 'shiftR' from "Data.Bits".
shiftRL :: Nat -> Key -> Nat
#if __GLASGOW_HASKELL__
{--------------------------------------------------------------------
  GHC: use unboxing to get @shiftRL@ inlined.
--------------------------------------------------------------------}
shiftRL (W# x) (I# i)
  = W# (shiftRL# x i)
#else
shiftRL x i   = shiftR x i
#endif


-- | Total size of a map, read from the annotation cached in the root node.
-- The empty map has size 0.
size :: IntMap a ix -> Int
size t = case t of
  Nil            -> 0
  Tip sz _ _     -> sz
  Bin sz _ _ _ _ -> sz

-- | 'True' exactly when the map is the empty node 'Nil'.
null :: IntMap a ix -> Bool
null t = case t of
  Nil -> True
  _   -> False

-- | Look up a key, given in its unsigned ('Nat') form.
--
-- At each 'Bin' the branching bit selects the subtree without checking the
-- prefix; the final key comparison at the 'Tip' decides hit or miss.
lookup :: Nat -> IntMap a ix -> Maybe (a ix)
lookup key t = case t of
  Bin _ _ m l r
    | zeroN key (natFromInt m) -> lookup key l
    | otherwise                -> lookup key r
  Tip _ kx x
    | key == natFromInt kx     -> Just x
  _                            -> Nothing

-- | Look up a key (in unsigned form) and return its value together with the
-- combined size of all entries that precede it in key order.
--
-- A root 'Bin' with prefix 0 and a negative mask splits on the sign bit: its
-- right subtree holds the negative keys, which come first in key order, so
-- keys found in the left subtree are offset by @size r@.
--
-- Returns 'Nothing' if the key is absent. This includes a single-entry map
-- whose key differs from the one requested; the previous version shadowed
-- @k@ in the 'Tip' case and so reported a hit for every key.
lookupIx :: Nat -> IntMap a ix -> Maybe (Int, a ix)
lookupIx k t = case t of
  Bin _ 0 m l r
    | m < 0 ->
        if zeroN k (natFromInt m)
          then lookupIx' (size r) k l
          else lookupIx' 0 k r
  Bin{} -> lookupIx' 0 k t
  Tip _ kx x
    | k == natFromInt kx -> Just (0, x)
  _ -> Nothing

-- | Find the entry covering position @i@, where positions are measured in
-- cached sizes. The result is @(start, key, value)@, with @start@ the
-- combined size of all entries before the one found.
--
-- A sign-bit root (prefix 0, negative mask) keeps the negative keys in its
-- right subtree, so that subtree is searched first. Returns 'Nothing' only
-- for the empty map; at a 'Tip' the index is not checked.
assocAt :: Int -> IntMap a ix -> Maybe (Int, Key, a ix)
assocAt !i t = case t of
  Bin _ 0 m l r
    | m < 0, i < nNeg -> assocAt' 0 i r
    | m < 0           -> assocAt' nNeg (i - nNeg) l
    where nNeg = size r
  Bin{}     -> assocAt' 0 i t
  Tip _ k x -> Just (0, k, x)
  Nil       -> Nothing

-- | Worker for 'assocAt' on subtrees without sign-bit handling.
--
-- @i0@ accumulates the size of everything skipped so far; @i@ is the
-- remaining offset within the current subtree.
assocAt' :: Int -> Int -> IntMap a ix -> Maybe (Int, Key, a ix)
assocAt' !i0 !i t = case t of
  Bin _ _ _ l r ->
    let leftSize = size l
    in if i < leftSize
         then assocAt' i0 i l
         else assocAt' (i0 + leftSize) (i - leftSize) r
  Tip _ k x -> Just (i0, k, x)
  Nil       -> Nothing

-- | Update (or delete) the entry covering position @i@, where positions are
-- measured in cached sizes. The callback receives the entry's starting
-- position, key and value; returning 'Nothing' removes the entry.
--
-- @s@ computes the size annotation for a replacement value.
--
-- A sign-bit root (prefix 0, negative mask) keeps the negative keys in its
-- right subtree, so positions below @size r@ fall in @r@ and the rest in @l@.
-- The untouched half is reattached with 'bin'; the previous version returned
-- only the updated subtree and silently dropped the other half.
updateAt :: Sized a -> (Int -> Key -> a ix -> Maybe (a ix)) -> Int -> IntMap a ix -> IntMap a ix
updateAt s f !i t = case t of
  Bin _ 0 m l r
    | m < 0 ->
        let sr = size r
        in if i < sr
             then bin 0 m l (updateAt' s 0 f i r)
             else bin 0 m (updateAt' s sr f (i - sr) l) r
  Bin{}      -> updateAt' s 0 f i t
  Tip _ kx x -> singletonMaybe s kx (f 0 kx x)
  Nil        -> Nil

-- | Worker for 'updateAt' on subtrees without sign-bit handling.
--
-- @i0@ is the combined size of everything before this subtree; @i@ is the
-- remaining offset inside it. On reaching a 'Tip' the callback is applied
-- with the entry's starting position @i0@; a 'Nothing' result deletes the
-- entry and 'bin' collapses the parent.
--
-- The 'Tip' and 'Nil' alternatives were missing, so every call ended in a
-- pattern-match failure once the descent reached a leaf.
updateAt' :: Sized a -> Int -> (Int -> Key -> a ix -> Maybe (a ix)) -> Int -> IntMap a ix -> IntMap a ix
updateAt' s !i0 f !i t = case t of
  Bin _ p m l r ->
    let sl = size l
    in if i < sl
         then bin p m (updateAt' s i0 f i l) r
         else bin p m l (updateAt' s (i0 + sl) f (i - sl) r)
  Tip _ kx x -> singletonMaybe s kx (f i0 kx x)
  Nil        -> Nil

-- | Worker for 'lookupIx': descend by branching bit, adding the size of
-- every left subtree that is skipped to the running offset @i@.
lookupIx' :: Int -> Nat -> IntMap a ix -> Maybe (Int, a ix)
lookupIx' !i k (Bin _ _ m l r)
  | zeroN k (natFromInt m) = lookupIx' i k l
  | otherwise              = lookupIx' (i + size l) k r
lookupIx' !i k (Tip _ kx x)
  | k == natFromInt kx     = Just (i, x)
lookupIx' _ _ _            = Nothing

-- | Build a one-entry map, caching the value's size as computed by @s@.
singleton :: Sized a -> Key -> a ix -> IntMap a ix
singleton s k a = Tip weight k a
  where
    weight = s a

-- | Like 'singleton', but an absent value yields the empty map.
singletonMaybe :: Sized a -> Key -> Maybe (a ix) -> IntMap a ix
singletonMaybe s k mv = case mv of
  Nothing -> Nil
  Just v  -> singleton s k v

-- | Insert, modify or delete the value at key @k@.
--
-- @f@ receives @Just@ the current value, or 'Nothing' if the key is absent;
-- a 'Nothing' result means "no entry". @s@ computes the size annotation of
-- any newly stored value.
alter :: Sized a -> (Maybe (a ix) -> Maybe (a ix)) -> Int -> IntMap a ix -> IntMap a ix
alter s f k t@(Bin _ p m l r)
  -- Key lies outside this subtree: attach a possibly empty new leaf beside it.
  | nomatch k p m = join k fresh p t
  | zero k m      = bin p m (alter s f k l) r
  | otherwise     = bin p m l (alter s f k r)
  where
    fresh = singletonMaybe s k (f Nothing)
alter s f k t@(Tip _ ky y)
  | k == ky   = singletonMaybe s k (f (Just y))
  | otherwise = case f Nothing of
      Just x  -> join k (singleton s k x) ky t
      Nothing -> t
alter s f k Nil = singletonMaybe s k (f Nothing)

-- | Apply an effectful function to every entry and rebuild the map with the
-- results. Effects of the left subtree are sequenced before those of the
-- right subtree. @s@ computes the size annotations of the new values.
traverseWithKey :: Applicative f => Sized b -> (Key -> a ix -> f (b ix)) -> IntMap a ix -> f (IntMap b ix)
traverseWithKey _ _ Nil = pure Nil
traverseWithKey s f (Tip _ kx x) = singleton s kx <$> f kx x
traverseWithKey s f (Bin _ p m l r) = bin p m <$> go l <*> go r
  where
    go sub = traverseWithKey s f sub

-- | Right fold over the entries with their keys.
--
-- At a sign-bit root (prefix 0, negative mask) the right subtree, which
-- holds the negative keys, is placed before the left one.
foldr :: (Key -> a ix -> b -> b) -> IntMap a ix -> b -> b
foldr f t@(Bin _ p m l r)
  | p == 0 && m < 0 = foldr' f r . foldr' f l
  | otherwise       = foldr' f t
foldr f (Tip _ k x) = f k x
foldr _ Nil         = id

-- | Right fold in plain tree order (left subtree before right), with no
-- special treatment of the sign bit.
foldr' :: (Key -> a ix -> b -> b) -> IntMap a ix -> b -> b
foldr' f (Bin _ _ _ l r) = foldr' f l . foldr' f r
foldr' f (Tip _ k x)     = f k x
foldr' _ Nil             = id

-- | Left fold over the entries with their keys.
--
-- At a sign-bit root (prefix 0, negative mask) the right subtree, which
-- holds the negative keys, is consumed before the left one.
foldl :: (Key -> b -> a ix -> b) -> IntMap a ix -> b -> b
foldl f t@(Bin _ p m l r)
  | p == 0 && m < 0 = foldl' f l . foldl' f r
  | otherwise       = foldl' f t
foldl f (Tip _ k x) = \acc -> f k acc x
foldl _ Nil         = id

-- | Left fold in plain tree order (left subtree consumed first), with no
-- special treatment of the sign bit. The prime marks it as the worker for
-- 'foldl'; it is not a strict fold.
foldl' :: (Key -> b -> a ix -> b) -> IntMap a ix -> b -> b
foldl' f (Bin _ _ _ l r) = foldl' f r . foldl' f l
foldl' f (Tip _ k x)     = \acc -> f k acc x
foldl' _ Nil             = id

-- | Split a map in two by applying @f@ to every entry. @f@ yields a pair of
-- optional values: the first component populates the first result map, the
-- second component the second. @s1@ and @s2@ compute the size annotations
-- for the respective result maps.
--
-- The empty map maps to a pair of empty maps; that clause was missing, so
-- calling this on an empty map was a pattern-match failure.
mapEither :: Sized b -> Sized c -> EitherMap Key (a ix) (b ix) (c ix) ->
  IntMap a ix -> (IntMap b ix, IntMap c ix)
mapEither s1 s2 f (Bin _ p m l r) = case (mapEither s1 s2 f l, mapEither s1 s2 f r) of
  ((lL, lR), (rL, rR)) -> (bin p m lL rL, bin p m lR rR)
mapEither s1 s2 f (Tip _ kx x) = (singletonMaybe s1 kx *** singletonMaybe s2 kx) (f kx x)
mapEither _ _ _ Nil = (Nil, Nil)

-- | Split a map around key @k@ into the entries below it, the result for
-- @k@ itself, and the entries above it.
--
-- When @k@ is present, @f@ is applied to its value and produces a triple
-- whose outer components are turned into single-entry maps at @k@ (via
-- 'sides') and whose middle component is returned as is.
--
-- A root with a negative mask splits on the sign bit: its right subtree
-- holds the negative keys, so it lies wholly below any non-negative @k@.
splitLookup :: Sized a -> SplitMap (a ix) x -> Key -> IntMap a ix -> (IntMap a ix ,Maybe x,IntMap a ix)
splitLookup s f k t@(Bin _ _ m l r)
  | m >= 0    = splitLookup' s f k t
  | k >= 0    = let (lt, found, gt) = splitLookup' s f k l
                in (union s r lt, found, gt)
  | otherwise = let (lt, found, gt) = splitLookup' s f k r
                in (lt, found, union s gt l)
splitLookup s f k t@(Tip _ ky y) = case compare k ky of
  GT -> (t, Nothing, Nil)
  LT -> (Nil, Nothing, t)
  EQ -> singletonMaybe s k `sides` f y
splitLookup _ _ _ Nil = (Nil, Nothing, Nil)

-- | Worker for 'splitLookup' on a subtree whose keys all share one sign.
--
-- If @k@ does not match the subtree's prefix, the whole subtree lies on one
-- side of @k@, decided by comparing @k@ with the prefix. Otherwise the
-- search continues in the child selected by the branching bit, and the
-- other child is merged into the matching side.
splitLookup' :: Sized a -> SplitMap (a ix) x -> Key -> IntMap a ix -> (IntMap a ix ,Maybe x,IntMap a ix)
splitLookup' s f k t@(Bin _ p m l r)
  | nomatch k p m =
      if k > p then (t, Nothing, Nil) else (Nil, Nothing, t)
  | zero k m =
      let (lt, found, gt) = splitLookup s f k l
      in (lt, found, union s gt r)
  | otherwise =
      let (lt, found, gt) = splitLookup s f k r
      in (union s l lt, found, gt)
splitLookup' s f k t@(Tip _ ky y) = case compare k ky of
  GT -> (t, Nothing, Nil)
  LT -> (Nil, Nothing, t)
  EQ -> singletonMaybe s k `sides` f y
splitLookup' _ _ _ Nil = (Nil, Nothing, Nil)

-- | Union of two maps. When a key occurs in both, the value from the first
-- map is kept. @s@ computes size annotations for re-inserted values.
union :: Sized a -> IntMap a ix -> IntMap a ix -> IntMap a ix
union s t1@(Bin _ p1 m1 l1 r1) t2@(Bin _ p2 m2 l2 r2)
  -- t1 branches on a higher bit: t2 is either disjoint or fits under one child of t1.
  | shorter m1 m2 =
      if nomatch p2 p1 m1 then disjoint
      else if zero p2 m1 then bin p1 m1 (union s l1 t2) r1
      else bin p1 m1 l1 (union s r1 t2)
  -- Symmetric case: t1 is either disjoint or fits under one child of t2.
  | shorter m2 m1 =
      if nomatch p1 p2 m2 then disjoint
      else if zero p1 m2 then bin p2 m2 (union s t1 l2) r2
      else bin p2 m2 l2 (union s t1 r2)
  | p1 == p2  = bin p1 m1 (union s l1 l2) (union s r1 r2)
  | otherwise = disjoint
  where
    disjoint = join p1 t1 p2 t2
-- A leaf on the left overrides whatever the right map holds at that key.
union s (Tip _ k x) t = alter s (\_ -> Just x) k t
-- A leaf on the right is used only if the left map has no value at that key.
union s t (Tip _ k x) = alter s (\old -> Just (fromMaybe x old)) k t
union _ Nil t = t
union _ t Nil = t

-- | Union of two maps with a combining function for keys present in both.
-- @f@ receives the key, the value from the first map and the value from the
-- second map; returning 'Nothing' drops the key from the result.
-- @s@ computes size annotations for stored values.
unionWithKey :: Sized a -> UnionFunc Key (a ix) -> IntMap a ix -> IntMap a ix -> IntMap a ix
unionWithKey s f t1@(Bin _ p1 m1 l1 r1) t2@(Bin _ p2 m2 l2 r2)
  | shorter m1 m2 = intoFirst
  | shorter m2 m1 = intoSecond
  | p1 == p2      = bin p1 m1 (merge l1 l2) (merge r1 r2)
  | otherwise     = apart
  where
    merge a b = unionWithKey s f a b
    apart = join p1 t1 p2 t2

    -- t2 is disjoint from t1 or belongs under one of t1's children.
    intoFirst
      | nomatch p2 p1 m1 = apart
      | zero p2 m1       = bin p1 m1 (merge l1 t2) r1
      | otherwise        = bin p1 m1 l1 (merge r1 t2)

    -- t1 is disjoint from t2 or belongs under one of t2's children.
    intoSecond
      | nomatch p1 p2 m2 = apart
      | zero p1 m2       = bin p2 m2 (merge t1 l2) r2
      | otherwise        = bin p2 m2 l2 (merge t1 r2)
unionWithKey s f (Tip _ k x) t = alter s combine k t
  where
    combine = maybe (Just x) (\y -> f k x y)
unionWithKey s f t (Tip _ k x) = alter s combine k t
  where
    -- Arguments flipped so the first map's value stays in first position.
    combine = maybe (Just x) (\y -> f k y x)
unionWithKey _ _ Nil t = t
unionWithKey _ _ t Nil = t

-- | Intersection of two maps. For every key present in both, @f@ receives
-- the key and both values; returning 'Nothing' drops the key.
-- @s@ computes size annotations for the resulting values.
intersectionWithKey :: Sized c -> IsectFunc Key (a ix) (b ix) (c ix) -> IntMap a ix -> IntMap b ix -> IntMap c ix
intersectionWithKey s f t1@(Bin _ p1 m1 l1 r1) t2@(Bin _ p2 m2 l2 r2)
  -- t2 can only overlap the child of t1 selected by t2's prefix.
  | shorter m1 m2 =
      if nomatch p2 p1 m1
        then Nil
        else intersectionWithKey s f (if zero p2 m1 then l1 else r1) t2
  -- t1 can only overlap the child of t2 selected by t1's prefix.
  | shorter m2 m1 =
      if nomatch p1 p2 m2
        then Nil
        else intersectionWithKey s f t1 (if zero p1 m2 then l2 else r2)
  | p1 == p2  = bin p1 m1 (intersectionWithKey s f l1 l2) (intersectionWithKey s f r1 r2)
  | otherwise = Nil
intersectionWithKey s f (Tip _ k x) t2 =
  case lookup (natFromInt k) t2 of
    Nothing -> Nil
    Just y  -> singletonMaybe s k (f k x y)
intersectionWithKey s f t1 (Tip _ k y) =
  case lookup (natFromInt k) t1 of
    Nothing -> Nil
    Just x  -> singletonMaybe s k (f k x y)
intersectionWithKey _ _ Nil _ = Nil
intersectionWithKey _ _ _ Nil = Nil

-- | Difference of two maps. Keys found only in the first map are kept
-- unchanged. For keys found in both, @f@ decides: 'Nothing' removes the
-- key, @Just v@ keeps it with value @v@.
-- @s@ computes size annotations for replacement values.
differenceWithKey :: Sized a -> (Key -> a ix -> b ix -> Maybe (a ix)) -> IntMap a ix -> IntMap b ix -> IntMap a ix
differenceWithKey s f t1@(Bin _ p1 m1 l1 r1) t2@(Bin _ p2 m2 l2 r2)
  -- t2 can only affect the child of t1 selected by t2's prefix.
  | shorter m1 m2 =
      if nomatch p2 p1 m1 then t1
      else if zero p2 m1 then bin p1 m1 (differenceWithKey s f l1 t2) r1
      else bin p1 m1 l1 (differenceWithKey s f r1 t2)
  -- Only the child of t2 selected by t1's prefix can affect t1.
  | shorter m2 m1 =
      if nomatch p1 p2 m2
        then t1
        else differenceWithKey s f t1 (if zero p1 m2 then l2 else r2)
  | p1 == p2  = bin p1 m1 (differenceWithKey s f l1 l2) (differenceWithKey s f r1 r2)
  | otherwise = t1
differenceWithKey s f t1@(Tip _ k x) t2 =
  case lookup (natFromInt k) t2 of
    Nothing -> t1
    Just y  -> singletonMaybe s k (f k x y)
differenceWithKey _ _ Nil _ = Nil
differenceWithKey s f t (Tip _ k y) = alter s (\mx -> mx >>= \x -> f k x y) k t
differenceWithKey _ _ t Nil = t

-- | Test whether every key of the first map also occurs in the second, with
-- the supplied relation holding between the two values at each such key.
isSubmapOfBy :: LEq (a ix) (b ix) -> LEq (IntMap a ix) (IntMap b ix)
isSubmapOfBy le t1@(Bin _ p1 m1 l1 r1) (Bin _ p2 m2 l2 r2)
  -- t1 branches on a higher bit than t2, so t1 cannot fit inside t2.
  | shorter m1 m2 = False
  -- t1 must fit entirely under the child of t2 selected by t1's prefix.
  | shorter m2 m1 = match p1 p2 m2 && isSubmapOfBy le t1 (if zero p1 m2 then l2 else r2)
  | otherwise     = p1 == p2 && isSubmapOfBy le l1 l2 && isSubmapOfBy le r1 r2
isSubmapOfBy _ Bin{} _ = False
isSubmapOfBy le (Tip _ k x) t = case lookup (natFromInt k) t of
  Nothing -> False
  Just y  -> x `le` y
isSubmapOfBy _ Nil _ = True


-- | Remove the entry with the largest key, returning it with the remaining
-- map, or 'Nothing' for the empty map. At a root with a negative mask the
-- left subtree is searched, since the right one holds the negative keys.
maxViewWithKey :: IntMap a ix -> Maybe ((Key, a ix), IntMap a ix)
maxViewWithKey Nil = Nothing
maxViewWithKey (Tip _ k y) = Just ((k, y), Nil)
maxViewWithKey (Bin _ p m l r)
  | m < 0     = let (result, l') = maxViewUnsigned l in Just (result, bin p m l' r)
  | otherwise = let (result, r') = maxViewUnsigned r in Just (result, bin p m l r')

-- | Worker for 'maxViewWithKey': always follows the right spine.
-- Calls 'error' on the empty map.
maxViewUnsigned :: IntMap a ix -> ((Key, a ix), IntMap a ix)
maxViewUnsigned Nil = error "maxViewUnsigned Nil"
maxViewUnsigned (Tip _ k y) = ((k, y), Nil)
maxViewUnsigned (Bin _ p m l r) =
  let (result, r') = maxViewUnsigned r in (result, bin p m l r')

-- | Remove the entry with the smallest key, returning it with the remaining
-- map, or 'Nothing' for the empty map. At a root with a negative mask the
-- right subtree is searched, since it holds the negative keys.
minViewWithKey :: IntMap a ix -> Maybe ((Key, a ix), IntMap a ix)
minViewWithKey Nil = Nothing
minViewWithKey (Tip _ k y) = Just ((k, y), Nil)
minViewWithKey (Bin _ p m l r)
  | m < 0     = let (result, r') = minViewUnsigned r in Just (result, bin p m l r')
  | otherwise = let (result, l') = minViewUnsigned l in Just (result, bin p m l' r)

-- | Worker for 'minViewWithKey': always follows the left spine.
-- Calls 'error' on the empty map.
minViewUnsigned :: IntMap a ix -> ((Key, a ix), IntMap a ix)
minViewUnsigned Nil = error "minViewUnsigned Nil"
minViewUnsigned (Tip _ k y) = ((k, y), Nil)
minViewUnsigned (Bin _ p m l r) =
  let (result, l') = minViewUnsigned l in (result, bin p m l' r)

-- | Update or delete the entry with the smallest key. @f@ returning
-- 'Nothing' removes the entry. At a root with a negative mask the minimum
-- is in the right subtree, which holds the negative keys.
updateMinWithKey :: Sized a -> (Key -> a ix -> Maybe (a ix)) -> IntMap a ix -> IntMap a ix
updateMinWithKey s f (Bin _ p m l r)
  | m < 0     = bin p m l (updateMinWithKeyUnsigned s f r)
  | otherwise = bin p m (updateMinWithKeyUnsigned s f l) r
updateMinWithKey s f (Tip _ k y) = singletonMaybe s k (f k y)
updateMinWithKey _ _ Nil = Nil

-- | Worker for 'updateMinWithKey': follows the left spine to the leftmost
-- leaf and applies @f@ there.
updateMinWithKeyUnsigned :: Sized a -> (Key -> a ix -> Maybe (a ix)) -> IntMap a ix -> IntMap a ix
updateMinWithKeyUnsigned s f (Bin _ p m l r) = bin p m (updateMinWithKeyUnsigned s f l) r
updateMinWithKeyUnsigned s f (Tip _ k y) = singletonMaybe s k (f k y)
updateMinWithKeyUnsigned _ _ Nil = Nil

-- | Update or delete the entry with the largest key. @f@ returning
-- 'Nothing' removes the entry. At a root with a negative mask the maximum
-- is in the left subtree, since the right one holds the negative keys.
updateMaxWithKey :: Sized a -> (Key -> a ix -> Maybe (a ix)) -> IntMap a ix -> IntMap a ix
updateMaxWithKey s f (Bin _ p m l r)
  | m < 0     = bin p m (updateMaxWithKeyUnsigned s f l) r
  | otherwise = bin p m l (updateMaxWithKeyUnsigned s f r)
updateMaxWithKey s f (Tip _ k y) = singletonMaybe s k (f k y)
updateMaxWithKey _ _ Nil = Nil

-- | Worker for 'updateMaxWithKey': follows the right spine to the rightmost
-- leaf and applies @f@ there.
updateMaxWithKeyUnsigned :: Sized a -> (Key -> a ix -> Maybe (a ix)) -> IntMap a ix -> IntMap a ix
updateMaxWithKeyUnsigned s f (Bin _ p m l r) = bin p m l (updateMaxWithKeyUnsigned s f r)
updateMaxWithKeyUnsigned s f (Tip _ k y) = singletonMaybe s k (f k y)
updateMaxWithKeyUnsigned _ _ Nil = Nil

-- | The prefix of key @i@ with respect to branching bit @m@: the bits of
-- @i@ strictly above @m@ (computed by 'maskW' on the unsigned views).
mask :: Key -> Mask -> Prefix
mask i m = maskW key bit
  where
    key = natFromInt i
    bit = natFromInt m

-- | 'True' when key @i@ has no bit in common with mask @m@, i.e. the key
-- belongs in the left subtree of a node branching on @m@.
zero :: Key -> Mask -> Bool
zero i m = zeroN (natFromInt i) (natFromInt m)

-- | 'True' when key @i@ does not share prefix @p@ above branching bit @m@.
nomatch :: Key -> Prefix -> Mask -> Bool
nomatch i p m = not (match i p m)

-- | 'True' when key @i@ shares prefix @p@ above branching bit @m@.
match :: Key -> Prefix -> Mask -> Bool
match i p m = p == mask i m

-- | Unsigned form of 'zero': 'True' when @i@ and @m@ share no set bit.
zeroN :: Nat -> Nat -> Bool
zeroN i m = case i .&. m of
  0 -> True
  _ -> False

-- | Keep only the bits of @i@ selected by @complement (m-1) `xor` m@.
-- For a single-bit @m@ these are the bits strictly above @m@.
maskW :: Nat -> Nat -> Prefix
maskW i m = intFromNat (i .&. keep)
  where
    keep = complement (m - 1) `xor` m

-- | 'True' when mask @m1@, compared as an unsigned word, is greater than
-- @m2@, i.e. @m1@ branches on a more significant bit.
shorter :: Mask -> Mask -> Bool
shorter m1 m2 = natFromInt m2 < natFromInt m1

-- | The branching bit for two differing prefixes: the most significant bit
-- at which @p1@ and @p2@ disagree.
branchMask :: Prefix -> Prefix -> Mask
branchMask p1 p2 = intFromNat (highestBitMask differing)
  where
    differing = natFromInt p1 `xor` natFromInt p2

-- | Isolate the most significant set bit of a word.
--
-- The highest set bit is first smeared into every lower position by
-- successive shift-and-or steps; xor-ing with the result shifted right by
-- one then leaves only the top bit.
highestBitMask :: Nat -> Nat
highestBitMask x0 = x6 `xor` shiftRL x6 1
  where
    x1 = x0 .|. shiftRL x0 1
    x2 = x1 .|. shiftRL x1 2
    x3 = x2 .|. shiftRL x2 4
    x4 = x3 .|. shiftRL x3 8
    x5 = x4 .|. shiftRL x4 16
    x6 = x5 .|. shiftRL x5 32   -- for 64 bit platforms

-- | Combine two trees with differing prefixes under a new branch node.
-- The branching bit is the highest bit where the prefixes differ; the tree
-- whose prefix has that bit clear becomes the left child.
join :: Prefix -> IntMap a ix -> Prefix -> IntMap a ix -> IntMap a ix
join p1 t1 p2 t2 =
  if zero p1 branchBit
    then bin common branchBit t1 t2
    else bin common branchBit t2 t1
  where
    branchBit = branchMask p1 p2
    common    = mask p1 branchBit

-- | Smart constructor for 'Bin'. If either child is empty the other child
-- is returned as is; otherwise a 'Bin' is built whose cached size is the
-- sum of the children's sizes.
bin :: Prefix -> Mask -> IntMap a ix -> IntMap a ix -> IntMap a ix
bin p m l r = case (l, r) of
  (_, Nil) -> l
  (Nil, _) -> r
  _        -> Bin (size l + size r) p m l r

-- import Data.Monoid
-- import Data.IntMap
-- import qualified Data.IntMap as IMap
-- import Data.Traversable
-- 
-- newtype IntTMap a ix = ITMap (IntMap (a ix))
-- type instance TrieMap Int = IntTMap
-- newtype MaybeF a ix = MF {unF :: Maybe (a ix)}
-- 
-- instance TrieKey Int IntTMap where
-- 	emptyM = ITMap empty
-- 	nullM (ITMap m) = IMap.null m
-- 	alterM _ f k (ITMap m) = ITMap (IMap.alter f k m)
-- 	lookupM k (ITMap m) = IMap.lookup k m
-- 	traverseWithKeyM _ f (ITMap m) = (ITMap . IMap.fromDistinctAscList) <$>
-- 		sequenceA (IMap.foldWithKey (\ k a xs -> (((,) k) <$> f k a):xs) [] m)
-- 	foldWithKeyM f (ITMap m) z = IMap.foldWithKey f z m
-- 	foldlWithKeyM f (ITMap m) z = foldl (\ z (k, a) -> f k z a) z (IMap.assocs m)
-- 	mapEitherM _ _ f (ITMap m) = (ITMap (mapMaybe fst m'), ITMap (mapMaybe snd m')) where
-- 		m' = mapWithKey f m
-- 	splitLookupM _ f k (ITMap m) = ITMap `sides` case splitLookup k m of
-- 		(mL, x, mR)
-- 			| Nothing <- x	-> (mL, Nothing, mR)
-- 			| Just x <- x, (xL, x, xR) <- f x
-- 				-> (mIns k mL xL, x, mIns k mR xR)
-- 		where	mIns k m = maybe m (\ x -> IMap.insert k x m)
-- 	unionM _ f (ITMap m1) (ITMap m2) = ITMap (mapMaybe unF (unionWithKey f' m1' m2')) where
-- 		f' k (MF a) (MF b) = MF (unionMaybe (f k) a b)
-- 		m1' = fmap (MF . Just) m1
-- 		m2' = fmap (MF . Just) m2
-- 	isectM _ f (ITMap m1) (ITMap m2) = ITMap (mapMaybe unF (intersectionWithKey f' m1' m2')) where
-- 		f' k (MF a) (MF b) = MF (isectMaybe (f k) a b)
-- 		m1' = fmap (MF . Just) m1
-- 		m2' = fmap (MF . Just) m2
-- 	diffM _ f (ITMap m1) (ITMap m2) = ITMap (differenceWithKey f m1 m2)
-- 	extractMinM _ (ITMap m) = fmap ITMap <$> First (minViewWithKey m)
-- 	extractMaxM _ (ITMap m) = fmap ITMap <$> Last (maxViewWithKey m)
-- 	alterMinM _ f (ITMap m) = ITMap $ case minViewWithKey m of
-- 		Just ((k, v), m') 
-- 				-> maybe m' (\ v' -> updateMin (const v') m) (f k v)
-- 		Nothing		-> m
-- 	alterMaxM _ f (ITMap m) = ITMap $ case maxViewWithKey m of
-- 		Just ((k, v), m')
-- 				-> maybe m' (\ v' -> updateMax (const v') m) (f k v)
-- 		Nothing		-> m
-- 	isSubmapM (<=) (ITMap m1) (ITMap m2) = isSubmapOfBy (<=) m1 m2
-- 	fromListM _ = ITMap .: fromListWithKey
-- 	fromAscListM _ = ITMap .: fromAscListWithKey
-- 	fromDistAscListM _ = ITMap . fromDistinctAscList