{-# LANGUAGE TemplateHaskell, TypeOperators, UndecidableInstances, BangPatterns, Rank2Types, CPP, MagicHash, PatternGuards, MultiParamTypeClasses, TypeFamilies #-}

module Data.TrieMap.IntMap () where

import Data.TrieMap.TrieKey
-- import Data.TrieMap.MultiRec.Base
-- import Data.TrieMap.Applicative
import Data.TrieMap.Sized
import Data.TrieMap.CPair
-- import Data.TrieMap.ReverseMap
-- import Data.TrieMap.Rep
-- import Data.TrieMap.Rep.TH

import Control.Applicative (Applicative(..), Alternative(..), (<$>))
import Control.Arrow
import Control.Monad (MonadPlus(..))

import Data.Bits
import Data.Maybe
import Data.Monoid
import Data.Word
-- import Data.Int

-- #if __GLASGOW_HASKELL__ >= 503
-- import GHC.Exts ( Word(..), Int(..), shiftRL# )
-- #elif __GLASGOW_HASKELL__
-- import Word
-- import GlaExts ( Word(..), Int(..), shiftRL# )
-- #else
-- import Data.Word
-- #endif

import Prelude hiding (lookup, null, foldl, foldr)

-- | Keys viewed as unsigned 32-bit naturals by the bit-twiddling helpers below.
type Nat = Word32

-- | A big-endian PATRICIA trie keyed by 'Word32' (nodes branch on the highest
-- bit in which their two halves differ; see 'branchMask').
--
-- Invariants kept by the smart constructor 'bin' and by 'join':
--
--   * a 'Bin' never has a 'Nil' child;
--   * keys in the left child have the 'Mask' bit clear, keys in the right child
--     have it set;
--   * every key below a 'Bin' agrees with its 'Prefix' on the bits above the
--     mask bit.
--
-- The 'Size' field caches the total 'Sized' weight of the subtree: a 'Tip'
-- stores @s a@ for its value, and 'bin' sums the sizes of its children.
data WordMap a = Nil
              | Tip {-# UNPACK #-} !Size {-# UNPACK #-} !Key (a)
              | Bin {-# UNPACK #-} !Size {-# UNPACK #-} !Prefix {-# UNPACK #-} !Mask !(WordMap a) !(WordMap a) 
		deriving (Show)
-- data IntMap a = IMap (WordMap a) (WordMap a)
-- Register 'WordMap' as the trie used for 'Word32' keys.
type instance TrieMap Word32 = WordMap
-- type instance TrieMap Int32 = IntMap

-- Bits shared by every key under a 'Bin', strictly above its mask bit (see 'maskW').
type Prefix = Word32
-- A single set bit marking the position at which a 'Bin' branches.
type Mask   = Word32
type Key    = Word32
-- Cached total 'Sized' weight of a subtree.
type Size   = Int

-- type instance RepT WordMap = FamT KeyFam (HFix (U :+: (K Size :*: K Key :*: X) :+:
-- 				(K Size :*: K Prefix :*: K Mask :*: A0 :*: A0)))
-- type instance Rep (WordMap a) = RepT WordMap (Rep a)
-- 
-- -- $(genRepT [d|
--    instance ReprT WordMap where
-- 	toRepT = FamT . toFix where
-- 		toFix = HIn . toFix'
-- 		toFix' Nil = L U
-- 		toFix' (Tip s kx x) = R (L (K s :*: K kx :*: X x))
-- 		toFix' (Bin s p m l r) = R (R (K s :*: K p :*: K m :*: A0 (toFix l) :*: A0 (toFix r)))
-- 	fromRepT (FamT m) = fromFix m where
-- 		fromFix (HIn x) = fromFix' x
-- 		fromFix' L{} = Nil
-- 		fromFix' (R (L (K s :*: K kx :*: X x))) = Tip s kx x
-- 		fromFix' (R (R (K s :*: K p :*: K m :*: A0 l :*: A0 r))) = Bin s p m (fromFix l) (fromFix r) |])

-- | 'TrieKey' instance for 'Word32' keys.  Every method delegates to the
-- 'WordMap' function of the same purpose defined in this module.
instance TrieKey Word32 WordMap where
	emptyM = Nil
	nullM = null
	-- Each node caches its size, so the Sized argument is not consulted.
	sizeM _ = size
	lookupM = lookup
	-- Index-based queries start counting positions at 0 for the root.
	lookupIxM s = lookupIx s 0
	assocAtM s = assocAt s 0
	-- NOTE(review): positional update is currently disabled.
-- 	updateAtM s r = updateAt s r 0
	alterM = alter
	alterLookupM = alterLookup
	traverseWithKeyM = traverseWithKey
	foldWithKeyM = foldr
	foldlWithKeyM = foldl
	mapEitherM = mapEither
	splitLookupM = splitLookup
	unionM = unionWithKey
	isectM = intersectionWithKey
	diffM = differenceWithKey
	extractM s f = extract s f
	-- NOTE(review): min/max extraction and alteration are currently disabled;
	-- the helpers they would use remain defined below.
-- 	extractMinM s f = First . minViewWithKey s f
-- 	extractMaxM s f = Last . maxViewWithKey s f
-- 	alterMinM = updateMinWithKey
-- 	alterMaxM = updateMaxWithKey
	isSubmapM = isSubmapOfBy

{-instance TrieKey Int32 IntMap where
	emptyM = IMap Nil Nil
	nullM (IMap mN mP) = nullM mN && nullM mP
	sizeM s (IMap mN mP) = sizeM s mN + sizeM s mP
	lookupM k (IMap mN mP)
		| k < 0		= lookupM (fromIntegral (-k)) mN
		| otherwise	= lookupM (fromIntegral k) mP
	lookupIxM s k (IMap mN mP)
		| k < 0		= do	(i, v) <- lookupIx' 0 s (fromIntegral (-k)) mN
					return (sizeM s mN - 1 - i, v)
		| otherwise	= do	(i, v) <- lookupIxM s (fromIntegral k) mP
					return (i + sizeM s mN, v)
	assocAtM s i (IMap mN mP)
		| i < sN, (i', k, a) <- assocAt' s i mN
			= (i', - fromIntegral k, a)
		| (i', k, a) <-assocAtM s (i - sN) mP
			= (i' + sN, fromIntegral k, a)
		where	sN = sizeM s mN
	updateAtM s f i (IMap mN mP)
		| i < sN	= updateAtM s (\ i' k -> f i' (- fromIntegral k)) (sN - 1 - i) mN `IMap` mP
		| otherwise	= mN `IMap` updateAtM s (\ i' k -> f (i' + sN) (fromIntegral k)) (i - sN) mP
		where	sN = sizeM s mN
	alterM s f k (IMap mN mP)
		| k < 0		= alterM s f (fromIntegral (- k)) mN `IMap` mP
		| otherwise	= mN `IMap` alterM s f (fromIntegral k) mP
	traverseWithKeyM s f (IMap mN mP) =
		IMap <$> traverseWithKeyM s (\ k -> f (- fromIntegral k)) mN <*>
			traverseWithKeyM s (f . fromIntegral) mP
	foldWithKeyM f (IMap mN mP) =
		foldlWithKeyM (\ k -> flip (f (- fromIntegral k))) mN . foldWithKeyM (f . fromIntegral) mP
	foldlWithKeyM f (IMap mN mP) =
		foldlWithKeyM (f . fromIntegral) mP . foldWithKeyM (\ k -> flip (f (- fromIntegral k))) mN
	mapEitherM s1 s2 f (IMap mN mP) = (IMap mNL mPL, IMap mNR mPR)
		where	(mNL, mNR) = mapEitherM s1 s2 (\ k -> f (- fromIntegral k)) mN
			(mPL, mPR) = mapEitherM s1 s2 (f . fromIntegral) mP
	splitLookupM s f k (IMap mN mP)
		| k < 0, (mNL, ans, mNR) <- splitLookupM s ((\ (l, x, r) -> (r, x, l)) . f) (fromIntegral (-k)) mN
			= (IMap mNR emptyM, ans, IMap mNL mP)
		| (mPL, ans, mPR) <- splitLookupM s f (fromIntegral k) mP
			= (IMap mN mPL, ans, IMap emptyM mPR)-}

-- | Reinterpret a key as a 'Nat'.  Both are 'Word32', so this is the identity;
-- it is kept so the bit helpers read like their Data.IntMap originals.
natFromInt :: Word32 -> Nat
natFromInt n = n

-- | Convert a 'Nat' back to a key-sized word.  Both are 'Word32', so this is
-- the identity.
intFromNat :: Nat -> Word32
intFromNat w = w

-- | Logical right shift of a 'Nat'.  The shift amount arrives as a 'Key' and
-- is converted to the 'Int' count 'shiftR' expects.  (A GHC-primop variant
-- once lived here; the portable 'shiftR' form is the one in use.)
shiftRL :: Nat -> Key -> Nat
shiftRL x i = x `shiftR` fromIntegral i


-- | Cached total 'Sized' weight of a map; 0 for the empty map.
size :: WordMap a -> Int
size t = case t of
  Nil            -> 0
  Tip sz _ _     -> sz
  Bin sz _ _ _ _ -> sz

-- | True exactly for the empty map.
null :: WordMap a -> Bool
null t = case t of
  Nil -> True
  _   -> False

-- | Find the value bound to a key.
--
-- The descent follows the key's bits without checking prefixes; only the
-- final 'Tip' compares the whole key, so a miss is detected there.
lookup :: Nat -> WordMap a -> Maybe (a)
lookup k = go
  where
    go (Bin _ _ m l r)
      | zeroN k m = go l
      | otherwise = go r
    go (Tip _ kx x)
      | k == kx   = Just x
    go _          = Nothing


-- | Locate the entry covering position @i@, where each entry occupies @s a@
-- positions.  @i0@ is the starting index of the current subtree, and a 'Tip'
-- reports @i0@ as its entry's index.
--
-- Returns @(lower, entry, upper)@.  On the way down, @lower@ is filled from
-- the maximum of the left sibling and @upper@ from the minimum of the right
-- sibling.  The inner min/max views pass @Just a@ back, so they only peek
-- at the element and do not change the map.
--
-- NOTE(review): @i@ is not range-checked.  A position at or beyond the total
-- size keeps descending right and reports the last entry.
assocAt :: Sized a -> Int -> Int -> WordMap a -> IndexPos Key a
assocAt s !i0 !i (Bin _ _ _ l r)
	-- Position falls in l: the successor candidate is r's minimum.
	| i < sl, (lb, x, ub) <- assocAt s i0 i l
		= (lb, x, ub <|> fst <$> First (minViewWithKey s (\ k a -> (Asc (i0 + size l) k a, Just a)) r))
	-- Otherwise search r; the predecessor candidate is l's maximum.
	| (lb, x, ub) <- assocAt s (i0 + sl) (i - sl) r
		= (fst <$> Last (maxViewWithKey s (\ k a -> (Asc (i0 + size l - s a) k a, Just a)) l) <|> lb, x, ub)
	where	sl = size l
assocAt _ i0 _ (Tip _ k x) = (mzero, return (Asc i0 k x), mzero)
assocAt _ _ _ _ = (mzero, mzero, mzero)

-- | Update the entry at position @i@, where positions are weighted by 'Sized'.
-- @f@ receives the entry's starting index, its key and its value; returning
-- 'Nothing' deletes the entry.  @i0@ is the starting index of the current
-- subtree.
--
-- The 'Round' flag only changes how a 'Bin' picks a side:
--
--   * 'True': go left iff @i < size l@.
--   * 'False': go left iff @i@ is below the starting index of @l@'s last
--     entry (computed by @maxIx@), so positions inside that last entry go right.
--
-- NOTE(review): in the 'False' case, a position inside @l@'s last entry
-- continues into @r@ with a negative offset @i - sl@.  Confirm this rounding
-- is intended.  The only call site visible here (updateAtM) is commented out.
updateAt :: Sized a -> Round -> Int -> (Int -> Key -> a -> Maybe (a)) -> Int -> WordMap a -> WordMap a
updateAt s True !i0 f !i t = case t of
	Bin _ p m l r -> let sl = size l in
		if i < sl then bin p m (updateAt s True i0 f i l) r 
			else bin p m l (updateAt s True (i0 + sl) f (i - sl) r)
	Tip _ kx x -> singletonMaybe s kx (f i0 kx x)
	_	-> t
updateAt s False !i0 f !i t = case t of
	Bin sz p m l r -> let {sl = size l; mI = maxIx l} in
		if i < mI then bin p m (updateAt s False i0 f i l) r
			else bin p m l (updateAt s False (i0 + sl) f (i - sl) r)
	Tip _ kx x -> singletonMaybe s kx (f i0 kx x)
	_	-> t
	-- Relative start index of the last entry in m (size m for an empty m).
	where	maxIx m = maybe (size m) fst (maxViewWithKey s (\ _ a -> (size m - s a, Just a)) m)

-- | Look up key @k@ and report @(lower, entry, upper)@.  The entry is tagged
-- with its starting index, counted from @i@ (the starting index of the
-- current subtree).  The neighbours come from the maximum of the left sibling
-- and the minimum of the right sibling seen on the way down.
--
-- NOTE(review): the descent never checks 'nomatch', and a 'Tip' holding a
-- different key returns no neighbour.  When @k@ is absent, the reported
-- bounds may therefore be incomplete or on the wrong side.  Confirm against
-- the contract of lookupIxM.
lookupIx :: Sized a -> Int -> Nat -> WordMap a -> IndexPos Nat a
lookupIx s !i k t = case t of
	Bin _ _ m l r
		| zeroN k m, (lb, x, ub) <- lookupIx s i k l
			-> (lb, x, ub <|> fst <$> First (minViewWithKey s (\ k a -> (Asc (i + size l) k a, Just a)) r))
		| (lb, x, ub) <- lookupIx s (i + size l) k r
			-> (fst <$> Last (maxViewWithKey s (\ k a -> (Asc (i + size l - s a) k a, Just a)) l) <|> lb, x, ub)
	Tip _ kx x
		| k == kx	-> (mzero, return (Asc i kx x), mzero)
	_ -> (mzero, mzero, mzero)

-- | A one-entry map; its cached size is the value's 'Sized' weight.
singleton :: Sized a -> Key -> a -> WordMap a
singleton s k a = let weight = s a in Tip weight k a

-- | 'singleton' for an optional value: 'Nothing' yields the empty map.
singletonMaybe :: Sized a -> Key -> Maybe (a) -> WordMap a
singletonMaybe s k mv = case mv of
  Nothing -> Nil
  Just v  -> singleton s k v

-- | Insert, update or delete the binding for @k@.  @f@ receives the current
-- value (if any), and its result decides the new binding; 'Nothing' removes it.
alter :: Sized a -> (Maybe (a) -> Maybe (a)) -> Key -> WordMap a -> WordMap a
alter s f k t = case t of
  Nil -> fresh
  Tip _ ky y
    | k == ky   -> singletonMaybe s k (f (Just y))
    -- Different key: link a new tip beside this one, or leave t untouched.
    | otherwise -> maybe t (\ x -> join k (singleton s k x) ky t) (f Nothing)
  Bin _ p m l r
    -- k lies outside this subtree's prefix, so it becomes a sibling.
    | nomatch k p m -> join k fresh p t
    | zero k m      -> bin p m (alter s f k l) r
    | otherwise     -> bin p m l (alter s f k r)
  where
    -- Subtree for k when it is not yet bound ('Nil' if f declines to insert).
    fresh = singletonMaybe s k (f Nothing)

-- | Like 'alter', but @f@ also produces an extra result (carried in the
-- 'CPair'), which is returned alongside the rebuilt map.
alterLookup :: Sized a -> (Maybe a -> CPair x (Maybe a)) -> Key -> WordMap a -> CPair x (WordMap a)
alterLookup s f k = go
  where
    -- k is absent here: link whatever f yields beside the existing subtree.
    insertBeside anchor t = fmap (\ v -> join k (singletonMaybe s k v) anchor t) (f Nothing)

    go t = case t of
      Nil -> fmap (singletonMaybe s k) (f Nothing)
      Tip _ ky y
        | k == ky   -> fmap (singletonMaybe s k) (f (Just y))
        | otherwise -> insertBeside ky t
      Bin _ p m l r
        | nomatch k p m -> insertBeside p t
        | zero k m      -> fmap (\ l' -> bin p m l' r) (go l)
        | otherwise     -> fmap (bin p m l) (go r)

-- | Apply an effectful function to every entry, rebuilding the map from the
-- results.  Effects run left subtree before right, i.e. in ascending key order.
traverseWithKey :: Applicative f => Sized b -> (Key -> a -> f (b)) -> WordMap a -> f (WordMap b)
traverseWithKey s f = go
  where
    go Nil             = pure Nil
    go (Tip _ kx x)    = fmap (singleton s kx) (f kx x)
    go (Bin _ p m l r) = bin p m <$> go l <*> go r

-- | Right fold over the entries in ascending key order.
foldr :: (Key -> a -> b -> b) -> WordMap a -> b -> b
foldr f = go
  where
    go Nil             = id
    go (Tip _ k x)     = f k x
    go (Bin _ _ _ l r) = go l . go r

-- | Lazy left fold over the entries in ascending key order.
foldl :: (Key -> b -> a -> b) -> WordMap a -> b -> b
foldl f = go
  where
    go Nil             = id
    go (Tip _ k x)     = \ z -> f k z x
    go (Bin _ _ _ l r) = go r . go l

-- | Split a map in two by applying @f@ to every entry.  @f k v@ yields a pair
-- of optional values: the first (if 'Just') is kept in the left result and
-- the second (if 'Just') in the right result.  An empty map splits into two
-- empty maps.
mapEither :: Sized b -> Sized c -> EitherMap Key (a) (b) (c) -> WordMap a -> (WordMap b, WordMap c)
mapEither s1 s2 f (Bin _ p m l r) = case (mapEither s1 s2 f l, mapEither s1 s2 f r) of
  -- 'bin' drops whichever halves came back empty.
  ((lL, lR), (rL, rR)) -> (bin p m lL rL, bin p m lR rR)
mapEither s1 s2 f (Tip _ kx x) = (singletonMaybe s1 kx *** singletonMaybe s2 kx) (f kx x)
-- Previously missing: mapEitherM on an empty map failed with a pattern-match error.
mapEither _ _ _ Nil = (Nil, Nil)

-- | Split a map around key @k@ into @(below, found, above)@.
--
-- Entries with keys below @k@ go to the first map and those above @k@ to the
-- last.  If @k@ is bound, @f@ is applied to its value.  'sides' wraps the
-- outer components of that result with 'singletonMaybe', so @f@ is assumed to
-- return (what to keep below, extracted result, what to keep above).
-- NOTE(review): confirm against the definition of SplitMap.
splitLookup :: Sized a -> SplitMap (a) x -> Key -> WordMap a -> (WordMap a ,Maybe x,WordMap a)
splitLookup s f k t
  = case t of
      Bin _ p m l r
        -- k is outside this prefix, so the whole subtree lies on one side.
        | nomatch k p m -> if k>p then (t,Nothing,Nil) else (Nil,Nothing,t)
        -- Pieces of a split child keep their side of the mask bit, so 'bin'
        -- reattaches them to the untouched sibling in O(1), dropping an empty
        -- piece; no general merge is needed.
        | zero k m  -> let (lt,found,gt) = splitLookup s f k l in (lt,found,bin p m gt r)
        | otherwise -> let (lt,found,gt) = splitLookup s f k r in (bin p m l lt,found,gt)
      Tip _ ky y
        | k>ky      -> (t,Nothing,Nil)
        | k<ky      -> (Nil,Nothing,t)
        | otherwise -> singletonMaybe s k `sides` f y
      Nil -> (Nil,Nothing,Nil)

-- | Left-biased union: when both maps bind a key, the value from the first
-- map is kept.
--
-- This is the standard PATRICIA merge.  The tree whose mask is numerically
-- larger (the 'shorter' prefix) is descended until both sides branch at the
-- same bit, or until their prefixes disagree and 'join' links them.
union :: Sized a -> WordMap a -> WordMap a -> WordMap a
union s t1@(Bin _ p1 m1 l1 r1) t2@(Bin _ p2 m2 l2 r2)
  | shorter m1 m2  = union1
  | shorter m2 m1  = union2
  | p1 == p2       = bin p1 m1 (union s l1 l2) (union s r1 r2)
  | otherwise      = join p1 t1 p2 t2
  where
    -- t1 branches higher: t2 belongs under one of t1's children, or beside t1.
    union1  | nomatch p2 p1 m1  = join p1 t1 p2 t2
            | zero p2 m1        = bin p1 m1 (union s l1 t2) r1
            | otherwise         = bin p1 m1 l1 (union s r1 t2)

    -- t2 branches higher: the symmetric case.
    union2  | nomatch p1 p2 m2  = join p1 t1 p2 t2
            | zero p1 m2        = bin p2 m2 (union s t1 l2) r2
            | otherwise         = bin p2 m2 l2 (union s t1 r2)
-- A left Tip overwrites any binding for its key in t (left value wins).
union s (Tip _ k x) t = alter s (const (Just x)) k t
-- A right Tip is inserted only when t lacks its key; otherwise t's value is
-- kept (left value wins).
union s t (Tip _ k x) = alter s (Just . fromMaybe x) k t
union _ Nil t       = t
union _ t Nil       = t

-- | Union of two maps.  For a key bound in both maps the values are combined
-- with @f k leftValue rightValue@; a 'Nothing' result removes the key.
--
-- This is the standard PATRICIA merge.  Descend into the tree whose mask is
-- numerically larger ('shorter' prefix), merge children that branch at the
-- same bit, and 'join' subtrees whose prefixes disagree.  'bin' collapses any
-- child emptied by @f@.
unionWithKey :: Sized a -> UnionFunc Key (a) -> WordMap a -> WordMap a -> WordMap a
unionWithKey s f t1@(Bin _ p1 m1 l1 r1) t2@(Bin _ p2 m2 l2 r2)
  | shorter m1 m2  = union1
  | shorter m2 m1  = union2
  | p1 == p2       = bin p1 m1 (unionWithKey s f l1 l2) (unionWithKey s f r1 r2)
  | otherwise      = join p1 t1 p2 t2
  where
    -- t1 branches higher: t2 belongs under one of t1's children, or beside t1.
    union1  | nomatch p2 p1 m1  = join p1 t1 p2 t2
            | zero p2 m1        = bin p1 m1 (unionWithKey s f l1 t2) r1
            | otherwise         = bin p1 m1 l1 (unionWithKey s f r1 t2)

    -- t2 branches higher: the symmetric case.
    union2  | nomatch p1 p2 m2  = join p1 t1 p2 t2
            | zero p1 m2        = bin p2 m2 (unionWithKey s f t1 l2) r2
            | otherwise         = bin p2 m2 l2 (unionWithKey s f t1 r2)
-- Left Tip: combine as f k x existing.
unionWithKey s f (Tip _ k x) t = alter s (maybe (Just x) (f k x)) k t
-- Right Tip: combine as f k existing x, keeping argument order left-to-right.
unionWithKey s f t (Tip _ k x) = alter s (maybe (Just x) (flip (f k) x)) k t
unionWithKey _ _ Nil t  = t
unionWithKey _ _ t Nil  = t

-- | Intersection of two maps.  Each key bound in both maps is given the value
-- @f k leftValue rightValue@; a 'Nothing' result drops the key.  Subtrees
-- whose prefixes disagree share no keys and yield 'Nil'.
intersectionWithKey :: Sized c -> IsectFunc Key (a) (b) (c) -> WordMap a -> WordMap b -> WordMap c
intersectionWithKey s f t1@(Bin _ p1 m1 l1 r1) t2@(Bin _ p2 m2 l2 r2)
  | shorter m1 m2  = intersection1
  | shorter m2 m1  = intersection2
  | p1 == p2       = bin p1 m1 (intersectionWithKey s f l1 l2) (intersectionWithKey s f r1 r2)
  | otherwise      = Nil
  where
    -- t1 branches higher: only the child of t1 that could contain t2 matters.
    intersection1 | nomatch p2 p1 m1  = Nil
                  | zero p2 m1        = intersectionWithKey s f l1 t2
                  | otherwise         = intersectionWithKey s f r1 t2

    -- t2 branches higher: the symmetric case.
    intersection2 | nomatch p1 p2 m2  = Nil
                  | zero p1 m2        = intersectionWithKey s f t1 l2
                  | otherwise         = intersectionWithKey s f t1 r2

-- A Tip on either side reduces to a lookup in the other map.
intersectionWithKey s f (Tip _ k x) t2
  = singletonMaybe s k (lookup (natFromInt k) t2 >>= f k x)
intersectionWithKey s f t1 (Tip _ k y) 
  = singletonMaybe s k (lookup (natFromInt k) t1 >>= flip (f k) y)
intersectionWithKey _ _ Nil _ = Nil
intersectionWithKey _ _ _ Nil = Nil

-- | Difference of two maps.  Keys bound only in the first map are kept, and
-- keys bound only in the second are ignored.  For a key bound in both,
-- @f k leftValue rightValue@ gives the surviving value, and 'Nothing' deletes it.
differenceWithKey :: Sized a -> (Key -> a -> b -> Maybe (a)) -> WordMap a -> WordMap b -> WordMap a
differenceWithKey s f t1@(Bin _ p1 m1 l1 r1) t2@(Bin _ p2 m2 l2 r2)
  | shorter m1 m2  = difference1
  | shorter m2 m1  = difference2
  | p1 == p2       = bin p1 m1 (differenceWithKey s f l1 l2) (differenceWithKey s f r1 r2)
  | otherwise      = t1
  where
    -- t1 branches higher: only the child of t1 that could overlap t2 changes.
    difference1 | nomatch p2 p1 m1  = t1
                | zero p2 m1        = bin p1 m1 (differenceWithKey s f l1 t2) r1
                | otherwise         = bin p1 m1 l1 (differenceWithKey s f r1 t2)

    -- t2 branches higher: only the child of t2 that could overlap t1 matters.
    difference2 | nomatch p1 p2 m2  = t1
                | zero p1 m2        = differenceWithKey s f t1 l2
                | otherwise         = differenceWithKey s f t1 r2

differenceWithKey s f t1@(Tip _ k x) t2 
  = maybe t1 (singletonMaybe s k . f k x) (lookup (natFromInt k) t2)
differenceWithKey _ _ Nil _       = Nil
-- Right Tip: adjust only that key in t (no-op if t lacks it).
differenceWithKey s f t (Tip _ k y) = alter s (>>= flip (f k) y) k t
differenceWithKey _ _ t Nil       = t

-- | @isSubmapOfBy le t1 t2@ holds when every key of @t1@ is bound in @t2@ and
-- @le@ holds between the corresponding values.
isSubmapOfBy :: LEq (a) (b) -> LEq (WordMap a) (WordMap b)
isSubmapOfBy le = go
  where
    go t1@(Bin _ p1 m1 l1 r1) (Bin _ p2 m2 l2 r2)
      -- t1 spans a wider key range than t2 can hold.
      | shorter m1 m2 = False
      -- t2 branches higher: t1 must fit entirely inside one child of t2.
      | shorter m2 m1 = match p1 p2 m2 && go t1 (if zero p1 m2 then l2 else r2)
      | otherwise     = p1 == p2 && go l1 l2 && go r1 r2
    go Bin{} _        = False
    go (Tip _ k x) t  = maybe False (le x) (lookup (natFromInt k) t)
    go Nil _          = True

-- | Offer every entry to @f@ in ascending key order, combining the
-- alternatives with '<|>'.  Each alternative carries the extra result and the
-- map rebuilt with that entry replaced by @f@'s 'Maybe' value (removed on
-- 'Nothing').  The empty map yields 'empty'.
extract :: Alternative f => Sized a -> (Key -> a -> f (CPair x (Maybe a))) -> WordMap a -> f (CPair x (WordMap a))
extract s f = go
  where
    go Nil             = empty
    go (Tip _ k x)     = fmap (singletonMaybe s k) <$> f k x
    go (Bin _ p m l r) = (fmap (\ l' -> bin p m l' r) <$> go l)
                     <|> (fmap (bin p m l) <$> go r)

-- | Apply @f@ to the entry with the largest key.  The result is @f@'s extra
-- output paired with the map in which that entry is replaced by @f@'s 'Maybe'
-- value ('Nothing' deletes it).  Returns 'Nothing' for the empty map.
-- Keys are unsigned, so this is just 'maxViewUnsigned' on a non-empty map.
maxViewWithKey, minViewWithKey :: Sized a -> (Key -> a -> (x, Maybe a)) -> WordMap a -> Maybe (x, WordMap a)
maxViewWithKey s f t = case t of
  Nil -> Nothing
  _   -> Just (maxViewUnsigned s f t)

-- | Worker for 'maxViewWithKey': follow right children to the largest key
-- and apply @f@ there.  The map must be non-empty; 'Nil' is an error.
maxViewUnsigned, minViewUnsigned :: Sized a -> (Key -> a -> (x, Maybe a)) -> WordMap a -> (x, WordMap a)
maxViewUnsigned s f = go
  where
    go (Bin _ p m l r) = let (res, r') = go r in (res, bin p m l r')
    go (Tip _ k y)     = let (res, y') = f k y in (res, singletonMaybe s k y')
    go Nil             = error "maxViewUnsigned Nil"

-- Apply f to the entry with the smallest key and return f's extra output with
-- the updated map, or Nothing for the empty map.  Keys are unsigned, so this
-- just wraps minViewUnsigned.
minViewWithKey s f t
  | null t    = Nothing
  | otherwise = Just (minViewUnsigned s f t)

-- Worker for minViewWithKey: follow left children to the smallest key and
-- apply f there.  The map must be non-empty; Nil is an error.
minViewUnsigned s f (Bin _ p m l r) =
  let (res, l') = minViewUnsigned s f l in (res, bin p m l' r)
minViewUnsigned s f (Tip _ k y) =
  let (res, y') = f k y in (res, singletonMaybe s k y')
minViewUnsigned _ _ Nil = error "minViewUnsigned Nil"

-- | Update (or, on 'Nothing', delete) the entry with the smallest key; the
-- empty map is returned unchanged.  Keys are unsigned, so no sign handling is
-- needed and this is exactly 'updateMinWithKeyUnsigned'.
updateMinWithKey :: Sized a -> (Key -> a -> Maybe (a)) -> WordMap a -> WordMap a
updateMinWithKey = updateMinWithKeyUnsigned

-- | Follow left children to the smallest key and replace its value with
-- @f k v@ ('Nothing' deletes the entry).  'Nil' is returned unchanged.
updateMinWithKeyUnsigned :: Sized a -> (Key -> a -> Maybe (a)) -> WordMap a -> WordMap a
updateMinWithKeyUnsigned s f = go
  where
    go (Bin _ p m l r) = bin p m (go l) r
    go (Tip _ k y)     = singletonMaybe s k (f k y)
    go Nil             = Nil

-- | Update (or, on 'Nothing', delete) the entry with the largest key; the
-- empty map is returned unchanged.  Keys are unsigned, so this is exactly
-- 'updateMaxWithKeyUnsigned'.
updateMaxWithKey :: Sized a -> (Key -> a -> Maybe (a)) -> WordMap a -> WordMap a
updateMaxWithKey = updateMaxWithKeyUnsigned

-- | Follow right children to the largest key and replace its value with
-- @f k v@ ('Nothing' deletes the entry).  'Nil' is returned unchanged.
updateMaxWithKeyUnsigned :: Sized a -> (Key -> a -> Maybe (a)) -> WordMap a -> WordMap a
updateMaxWithKeyUnsigned s f = go
  where
    go (Bin _ p m l r) = bin p m l (go r)
    go (Tip _ k y)     = singletonMaybe s k (f k y)
    go Nil             = Nil

-- | The prefix of key @i@ relative to mask @m@: @i@ with the mask bit and
-- every lower bit cleared.
mask :: Key -> Mask -> Prefix
mask i = maskW (natFromInt i) . natFromInt

-- | True when key @i@ has none of the bits of @m@ set, i.e. it belongs in the
-- left child of a 'Bin' with mask @m@.
zero :: Key -> Mask -> Bool
zero i m = zeroN (natFromInt i) (natFromInt m)

-- | @match i p m@: key @i@ agrees with prefix @p@ on the bits above mask @m@,
-- so it belongs somewhere in that subtree.  'nomatch' is its negation.
nomatch,match :: Key -> Prefix -> Mask -> Bool
nomatch i p m = not (match i p m)

match i p m = mask i m == p

-- | True when @i@ and @m@ share no set bits.
zeroN :: Nat -> Nat -> Bool
zeroN i m = 0 == i .&. m

-- | Keep only the bits of @i@ strictly above the single set bit of @m@.
maskW :: Nat -> Nat -> Prefix
maskW i m = intFromNat (i .&. above)
  where
    -- For a one-bit m, complement (m - 1) covers m and every higher bit;
    -- the xor then clears m itself.
    above = m `xor` complement (m - 1)

-- | True when @m1@ branches at a higher bit than @m2@, i.e. a subtree with
-- mask @m1@ has a shorter prefix and covers a wider key range.
shorter :: Mask -> Mask -> Bool
shorter m1 m2 = natFromInt m2 < natFromInt m1

-- | The mask for a 'Bin' separating prefixes @p1@ and @p2@: the highest bit in
-- which they differ.
branchMask :: Prefix -> Prefix -> Mask
branchMask p1 p2 = intFromNat (highestBitMask diff)
  where
    diff = natFromInt p1 `xor` natFromInt p2

-- | Isolate the most significant set bit of @x0@ (0 stays 0).
--
-- Repeated shift-and-or copies the top set bit into every lower position;
-- xoring that with itself shifted right by one leaves only the top bit.
-- 'Nat' is 'Word32', so the final 32-bit step shifts out everything (GHC's
-- 'shiftR' yields 0 at the full width) and is a no-op.  It only matters if
-- 'Nat' is ever widened to 64 bits.
highestBitMask :: Nat -> Nat
highestBitMask x0 = smeared `xor` shiftRL smeared 1
  where
    spread x n = x .|. shiftRL x n
    smeared = x0 `spread` 1 `spread` 2 `spread` 4 `spread` 8 `spread` 16 `spread` 32

-- | Link two subtrees whose prefixes @p1@ and @p2@ differ, under a new 'Bin'
-- that branches at their highest differing bit.  The subtree whose prefix has
-- that bit clear goes left.
join :: Prefix -> WordMap a -> Prefix -> WordMap a -> WordMap a
join p1 t1 p2 t2 =
  if zero p1 m then bin p m t1 t2 else bin p m t2 t1
  where
    m = branchMask p1 p2
    p = mask p1 m

-- | Smart 'Bin' constructor.  An empty child collapses the node to its other
-- child; otherwise the cached size is the sum of the children's sizes.
bin :: Prefix -> Mask -> WordMap a -> WordMap a -> WordMap a
bin p m l r = case (l, r) of
  (_, Nil) -> l
  (Nil, _) -> r
  _        -> Bin (size l + size r) p m l r