{-# LANGUAGE UnboxedTuples, BangPatterns, TypeFamilies, PatternGuards, MagicHash, CPP #-}

module Data.TrieMap.IntMap () where

import Data.TrieMap.TrieKey
import Data.TrieMap.Sized

import Control.Applicative (Applicative(..), Alternative(..), (<$>))

import Data.Bits
import Data.Maybe hiding (mapMaybe)
import Data.Word

import Prelude hiding (lookup, null, foldl, foldr)

#include "MachDeps.h"
#if WORD_SIZE_IN_BITS == 32
import GHC.Prim
import GHC.Word

complement32 (W32# w#) = W32# (not# w#)
#elif WORD_SIZE_IN_BITS > 32
-- Flip all 32 bits by xor with the all-ones Word32.  maxBound names that
-- mask directly instead of relying on bit 32 overflowing to 0 in a
-- 32-bit type.
complement32 = xor maxBound
#else
import GHC.Prim
import GHC.IntWord32
complement32 (W32# w#) = W32# (not32# w#)
#endif
-- Bitwise complement of a Word32; the implementation is chosen above by
-- platform word size.  The RULES pragma below asks GHC to rewrite
-- complement at type Word32 to this function.
complement32 :: Word32 -> Word32

{-# RULES
	"complement/Word32" complement = complement32
	#-}

-- Key-space aliases: keys, branch prefixes and branch masks are all plain
-- 32-bit words.  Nat is the same type, used where a value is treated as an
-- unsigned bit pattern.
type Key    = Word32
type Prefix = Word32
type Mask   = Word32
type Nat    = Word32

-- Cached size stored in Tip and Bin nodes (a sum of per-value sizes).
type Size   = Int

instance TrieKey Word32 where
  -- Patricia tree branching on the highest differing key bit (see
  -- branchMask).  Nil is empty, Tip holds one binding with its cached size,
  -- and Bin is an inner node with cached size, prefix, branching mask and
  -- the two subtrees.
  data TrieMap Word32 a
    = Nil
    | Tip {-# UNPACK #-} !Size {-# UNPACK #-} !Key a
    | Bin {-# UNPACK #-} !Size {-# UNPACK #-} !Prefix {-# UNPACK #-} !Mask
          !(TrieMap Word32 a) !(TrieMap Word32 a)

  -- construction and inspection
  emptyM = Nil
  singletonM = singleton
  nullM = null
  sizeM _ = size
  lookupM = lookup

  -- single-key updates
  alterM = alter
  alterLookupM = alterLookup

  -- traversals and folds
  traverseWithKeyM = traverseWithKey
  foldWithKeyM = foldr
  foldlWithKeyM = foldl
  mapMaybeM = mapMaybe
  mapEitherM = mapEither

  -- splitting and set operations
  splitLookupM = splitLookup
  unionM = unionWithKey
  isectM = intersectionWithKey
  diffM = differenceWithKey
  extractM sz g = extract sz g
  isSubmapM = isSubmapOfBy

-- Identity view of a Word32 as a Nat (both are Word32).
natFromInt :: Word32 -> Nat
natFromInt w = w

-- Identity view of a Nat back as a Word32 (both are Word32).
intFromNat :: Nat -> Word32
intFromNat n = n

-- Right shift of a Nat.  The shift count arrives as a Key (Word32) and is
-- converted to the Int count that shiftR takes.
shiftRL :: Nat -> Key -> Nat
shiftRL x i = x `shiftR` fromIntegral i


-- Cached size of a tree: 0 for Nil, otherwise the size field of the node.
size :: TrieMap Word32 a -> Int
size t = case t of
  Bin n _ _ _ _ -> n
  Tip n _ _     -> n
  Nil           -> 0

-- True exactly for the empty tree.
null :: TrieMap Word32 a -> Bool
null t = case t of
  Nil -> True
  _   -> False

-- Find the value bound to a key.  Inner nodes are descended by testing the
-- key against the branching mask only; the key is compared exactly at the
-- leaf, so a prefix mismatch simply ends in Nothing.
lookup :: Nat -> TrieMap Word32 a -> Maybe a
lookup k = go
  where
    go (Bin _ _ m l r)
      | zeroN k m = go l
      | otherwise = go r
    go (Tip _ kx x)
      | k == kx = Just x
    go _ = Nothing

-- One-element tree; the leaf caches the size reported by the Sized function.
singleton :: Sized a -> Key -> a -> TrieMap Word32 a
singleton sz key v = Tip (sz v) key v

-- Empty tree for Nothing, otherwise a one-element tree.
singletonMaybe :: Sized a -> Key -> Maybe a -> TrieMap Word32 a
singletonMaybe s k mv = case mv of
  Nothing -> Nil
  Just v  -> singleton s k v

-- Insert, update or delete the binding for k.  f receives the current value
-- (Nothing if absent); a Nothing result removes the key, Just v stores v.
alter :: Sized a -> (Maybe a -> Maybe a) -> Key -> TrieMap Word32 a -> TrieMap Word32 a
alter s f k = go
  where
    fresh = singletonMaybe s k
    go t@(Bin _ p m l r)
      -- k lies outside this subtree: graft a new leaf beside it
      -- (join drops the empty side when f Nothing is Nothing)
      | nomatch k p m = join k (fresh (f Nothing)) p t
      | zero k m      = bin p m (go l) r
      | otherwise     = bin p m l (go r)
    go t@(Tip _ ky y)
      | k == ky   = fresh (f (Just y))
      | otherwise = maybe t (\x -> join k (Tip (s x) k x) ky t) (f Nothing)
    go Nil = fresh (f Nothing)

-- Like alter, but f also yields an extra result x that is returned alongside
-- the rebuilt tree (threaded through onUnboxed at each level).
alterLookup :: Sized a -> (Maybe a -> (# x, Maybe a #)) -> Key -> TrieMap Word32 a -> (# x, TrieMap Word32 a #)
alterLookup s f k = go
  where
    single = singletonMaybe s k
    go t@(Bin _ p m l r)
      | nomatch k p m = onUnboxed (\v -> join k (single v) p t) f Nothing
      | zero k m      = onUnboxed (\l' -> bin p m l' r) go l
      | otherwise     = onUnboxed (\r' -> bin p m l r') go r
    go t@(Tip _ ky y)
      | k == ky   = onUnboxed single f (Just y)
      | otherwise = onUnboxed (\v -> join k (single v) ky t) f Nothing
    go Nil = onUnboxed single f Nothing

-- Applicative traversal over all bindings, left subtree before right.
-- Leaves are rebuilt with sizes from the new Sized function.
traverseWithKey :: Applicative f => Sized b -> (Key -> a -> f b) -> TrieMap Word32 a -> f (TrieMap Word32 b)
traverseWithKey s f = go
  where
    go Nil             = pure Nil
    go (Tip _ kx x)    = singleton s kx <$> f kx x
    go (Bin _ p m l r) = bin p m <$> go l <*> go r

-- Right fold over bindings: the left subtree is folded outermost, so keys
-- are combined in left-to-right tree order.
foldr :: (Key -> a -> b -> b) -> TrieMap Word32 a -> b -> b
foldr f = go
  where
    go (Bin _ _ _ l r) = go l . go r
    go (Tip _ k x)     = f k x
    go Nil             = id

-- Left fold over bindings: the left subtree is applied to the accumulator
-- first, then the right one.  The fold is lazy in the accumulator.
foldl :: (Key -> b -> a -> b) -> TrieMap Word32 a -> b -> b
foldl f = go
  where
    go (Bin _ _ _ l r) = go r . go l
    go (Tip _ k x)     = \acc -> f k acc x
    go Nil             = id

-- Map and filter in one pass; bin collapses subtrees that become empty.
mapMaybe :: Sized b -> (Key -> a -> Maybe b) -> TrieMap Word32 a -> TrieMap Word32 b
mapMaybe s f = go
  where
    go t = case t of
      Bin _ p m l r -> bin p m (go l) (go r)
      Tip _ kx x    -> singletonMaybe s kx (f kx x)
      _             -> Nil

mapEither :: Sized b -> Sized c -> EitherMap Key a b c ->
	TrieMap Word32 a -> (# TrieMap Word32 b, TrieMap Word32 c #)
-- Split one tree into two by f; each leaf goes to either, both, or neither
-- side depending on what f produces.  The left subtree is processed first.
mapEither s1 s2 f = go
  where
    go (Bin _ p m l r) = case go l of
      (# lL, lR #) -> case go r of
        (# rL, rR #) -> (# bin p m lL rL, bin p m lR rR #)
    go (Tip _ kx x) = both (singletonMaybe s1 kx) (singletonMaybe s2 kx) (f kx) x
    go _ = (# Nil, Nil #)

-- Split the tree around key k into (keys below k, result for k, keys above
-- k).  The middle component comes from sides applied to the value at k, if
-- any.
splitLookup :: Sized a -> SplitMap a x -> Key -> TrieMap Word32 a -> (# TrieMap Word32 a ,Maybe x,TrieMap Word32 a #)
splitLookup s f k = go
  where
    go t@(Bin _ p m l r)
      -- k differs from every key here above bit m, so comparing with the
      -- prefix decides which side the whole subtree falls on
      | nomatch k p m = if k > p then (# t, Nothing, Nil #) else (# Nil, Nothing, t #)
      | zero k m = case go l of
          (# lt, found, gt #) -> (# lt, found, union s gt r #)
      | otherwise = case go r of
          (# lt, found, gt #) -> (# union s l lt, found, gt #)
    go t@(Tip _ ky y) = case compare k ky of
      GT -> (# t, Nothing, Nil #)
      LT -> (# Nil, Nothing, t #)
      EQ -> sides (singletonMaybe s k) f y
    go _ = (# Nil, Nothing, Nil #)

-- | Left-biased union: for a key present in both maps, the value from the
-- first map is kept.  @s@ supplies sizes for any leaf rebuilt by alter.
union :: Sized a -> TrieMap Word32 a -> TrieMap Word32 a -> TrieMap Word32 a
union _ Nil t       = t
union _ t Nil       = t
-- A lone left entry overwrites whatever the right map holds for its key.
union s (Tip _ k x) t = alter s (const (Just x)) k t
-- A lone right entry is inserted only when the left map lacks its key
-- (left bias: an existing value in t is kept).
union s t (Tip _ k x) = alter s (Just . fromMaybe x) k t
union s t1@(Bin _ p1 m1 l1 r1) t2@(Bin _ p2 m2 l2 r2)
  | shorter m1 m2  = union1
  | shorter m2 m1  = union2
  | p1 == p2       = bin p1 m1 (union s l1 l2) (union s r1 r2)
  | otherwise      = join p1 t1 p2 t2
  where
    -- t1 branches on a higher bit: merge t2 into the matching child of t1,
    -- keeping t1 as the left (winning) operand.
    union1  | nomatch p2 p1 m1  = join p1 t1 p2 t2
            | zero p2 m1        = bin p1 m1 (union s l1 t2) r1
            | otherwise         = bin p1 m1 l1 (union s r1 t2)

    -- t2 branches on a higher bit: merge t1 into the matching child of t2,
    -- still passing t1 as the left operand.
    union2  | nomatch p1 p2 m2  = join p1 t1 p2 t2
            | zero p1 m2        = bin p2 m2 (union s t1 l2) r2
            | otherwise         = bin p2 m2 l2 (union s t1 r2)

-- Union of two trees; for a key present in both, f k left right decides the
-- result (Nothing drops the key).
unionWithKey :: Sized a -> UnionFunc Key a -> TrieMap Word32 a -> TrieMap Word32 a -> TrieMap Word32 a
unionWithKey s f = go
  where
    go Nil t = t
    go t Nil = t
    go (Tip _ k x) t = alter s (maybe (Just x) (f k x)) k t
    go t (Tip _ k x) = alter s (maybe (Just x) (flip (f k) x)) k t
    go t1@(Bin _ p1 m1 l1 r1) t2@(Bin _ p2 m2 l2 r2)
      | shorter m1 m2 = descendLeft
      | shorter m2 m1 = descendRight
      | p1 == p2      = bin p1 m1 (go l1 l2) (go r1 r2)
      | otherwise     = join p1 t1 p2 t2
      where
        -- t1 splits on a higher bit: push t2 into one child of t1
        descendLeft
          | nomatch p2 p1 m1 = join p1 t1 p2 t2
          | zero p2 m1       = bin p1 m1 (go l1 t2) r1
          | otherwise        = bin p1 m1 l1 (go r1 t2)
        -- t2 splits on a higher bit: push t1 into one child of t2
        descendRight
          | nomatch p1 p2 m2 = join p1 t1 p2 t2
          | zero p1 m2       = bin p2 m2 (go t1 l2) r2
          | otherwise        = bin p2 m2 l2 (go t1 r2)

-- Intersection of two trees; for each shared key, f k left right gives the
-- result value (Nothing drops the key).
intersectionWithKey :: Sized c -> IsectFunc Key a b c -> TrieMap Word32 a -> TrieMap Word32 b -> TrieMap Word32 c
intersectionWithKey s f = go
  where
    go Nil _ = Nil
    go _ Nil = Nil
    go (Tip _ k x) t2 = singletonMaybe s k (lookup (natFromInt k) t2 >>= f k x)
    go t1 (Tip _ k y) = singletonMaybe s k (lookup (natFromInt k) t1 >>= flip (f k) y)
    go t1@(Bin _ p1 m1 l1 r1) t2@(Bin _ p2 m2 l2 r2)
      -- one side splits higher: only the matching child can overlap
      | shorter m1 m2 = if nomatch p2 p1 m1 then Nil
                        else go (if zero p2 m1 then l1 else r1) t2
      | shorter m2 m1 = if nomatch p1 p2 m2 then Nil
                        else go t1 (if zero p1 m2 then l2 else r2)
      | p1 == p2      = bin p1 m1 (go l1 l2) (go r1 r2)
      | otherwise     = Nil

-- Difference of two trees: keys only in the first are kept; for shared keys
-- f k left right decides (Nothing removes the key, Just v replaces it).
differenceWithKey :: Sized a -> (Key -> a -> b -> Maybe a) -> TrieMap Word32 a -> TrieMap Word32 b -> TrieMap Word32 a
differenceWithKey s f = go
  where
    go Nil _ = Nil
    go t Nil = t
    go t1@(Tip _ k x) t2 =
      maybe t1 (singletonMaybe s k . f k x) (lookup (natFromInt k) t2)
    go t (Tip _ k y) = alter s (>>= flip (f k) y) k t
    go t1@(Bin _ p1 m1 l1 r1) t2@(Bin _ p2 m2 l2 r2)
      | shorter m1 m2 = intoFirst
      | shorter m2 m1 = intoSecond
      | p1 == p2      = bin p1 m1 (go l1 l2) (go r1 r2)
      | otherwise     = t1
      where
        -- t1 splits higher: subtract t2 from the matching child of t1
        intoFirst
          | nomatch p2 p1 m1 = t1
          | zero p2 m1       = bin p1 m1 (go l1 t2) r1
          | otherwise        = bin p1 m1 l1 (go r1 t2)
        -- t2 splits higher: only the matching child of t2 can overlap t1
        intoSecond
          | nomatch p1 p2 m2 = t1
          | zero p1 m2       = go t1 l2
          | otherwise        = go t1 r2

-- True when every key of the first tree occurs in the second and the
-- supplied comparison holds between the paired values.
isSubmapOfBy :: LEq a b -> LEq (TrieMap Word32 a) (TrieMap Word32 b)
isSubmapOfBy le = go
  where
    go t1@(Bin _ p1 m1 l1 r1) t2 = case t2 of
      Bin _ p2 m2 l2 r2
        -- a tree that splits higher cannot fit inside one that splits lower
        | shorter m1 m2 -> False
        | shorter m2 m1 -> match p1 p2 m2 && go t1 (if zero p1 m2 then l2 else r2)
        | otherwise     -> p1 == p2 && go l1 l2 && go r1 r2
      _ -> False
    go (Tip _ k x) t = maybe False (le x) (lookup (natFromInt k) t)
    go Nil _ = True

-- Offer every binding to f, left subtree alternatives first, and rebuild the
-- tree around the (possibly removed) chosen binding.
extract :: Alternative f => Sized a -> (Key -> a -> f (x, Maybe a)) -> TrieMap Word32 a -> f (x, TrieMap Word32 a)
extract s f t = case t of
  Bin _ p m l r ->
    (fmap (\l' -> bin p m l' r) <$> extract s f l)
      <|> (fmap (bin p m l) <$> extract s f r)
  Tip _ k x -> fmap (singletonMaybe s k) <$> f k x
  _ -> empty

-- Prefix of key i above the branching bit m.
mask :: Key -> Mask -> Prefix
mask i = maskW (natFromInt i) . natFromInt

-- True when key i has none of the bits of m set (goes to the left child).
zero :: Key -> Mask -> Bool
zero i m = 0 == natFromInt i .&. natFromInt m

-- match: key i agrees with prefix p on every bit above m.
-- nomatch: the negation.
nomatch, match :: Key -> Prefix -> Mask -> Bool
match i p m = mask i m == p

nomatch i p m = not (match i p m)

-- Nat version of zero: no bit of m is set in i.
zeroN :: Nat -> Nat -> Bool
zeroN i m = 0 == i .&. m

-- Keep the bits of i strictly above the single-bit mask m.  In Word32
-- arithmetic negate m equals complement (m - 1), i.e. all bits from m upward;
-- the xor then clears bit m itself.
maskW :: Nat -> Nat -> Prefix
maskW i m = intFromNat (i .&. (negate m `xor` m))

-- True when mask m1 branches on a higher bit than m2 (unsigned comparison).
shorter :: Mask -> Mask -> Bool
shorter m1 m2 = natFromInt m2 < natFromInt m1

-- Single-bit mask of the highest bit on which p1 and p2 differ.
branchMask :: Prefix -> Prefix -> Mask
branchMask p1 p2 = intFromNat . highestBitMask $ natFromInt p1 `xor` natFromInt p2

-- | Isolate the most significant set bit: 0 maps to 0, any other word maps
-- to the single-bit mask of its highest 1 bit.
--
-- Nat is always Word32, so smearing by 1, 2, 4, 8 and 16 covers every bit
-- on all platforms; a shift by 32 would always yield 0 on a 32-bit word and
-- is not needed.
highestBitMask :: Nat -> Nat
highestBitMask x0 = x5 `xor` shiftRL x5 1
  where
    -- After these steps every bit at or below the highest set bit of x0 is
    -- 1; xor with the word shifted right by one leaves only that top bit.
    x1 = x0 .|. shiftRL x0 1
    x2 = x1 .|. shiftRL x1 2
    x3 = x2 .|. shiftRL x2 4
    x4 = x3 .|. shiftRL x3 8
    x5 = x4 .|. shiftRL x4 16

-- Combine two trees with disjoint prefixes p1 and p2 under a new node that
-- branches on their highest differing bit; the side with a 0 there goes left.
join :: Prefix -> TrieMap Word32 a -> Prefix -> TrieMap Word32 a -> TrieMap Word32 a
join p1 t1 p2 t2 =
  if zero p1 m then bin p m t1 t2 else bin p m t2 t1
  where
    m = branchMask p1 p2
    p = mask p1 m

-- Smart Bin constructor: an empty child collapses the node to the other
-- child; otherwise the node caches the sum of both child sizes.
bin :: Prefix -> Mask -> TrieMap Word32 a -> TrieMap Word32 a -> TrieMap Word32 a
bin p m l r
  | null r    = l
  | null l    = r
  | otherwise = Bin (size l + size r) p m l r