{-# LANGUAGE UnboxedTuples, BangPatterns, TypeFamilies, PatternGuards, MagicHash, CPP, NamedFieldPuns, FlexibleInstances, RecordWildCards #-} {-# LANGUAGE TemplateHaskell, TypeOperators, MultiParamTypeClasses #-} {-# OPTIONS -funbox-strict-fields -O -fspec-constr -fliberate-case -fstatic-argument-transformation #-} module Data.TrieMap.WordMap (SNode, WHole, TrieMap(WordMap), Hole(Hole), WordStack, getWordMap, getHole) where import Data.TrieMap.TrieKey import Data.TrieMap.Sized import Control.Exception (assert) import Control.Monad.Lookup import Control.Monad.Unpack import Data.Bits import Data.Maybe hiding (mapMaybe) import GHC.Exts import Prelude hiding (lookup, null, map, foldl, foldr, foldl1, foldr1) #include "MachDeps.h" #define NIL SNode{node = Nil} #define TIP(args) SNode{node = (Tip args)} #define BIN(args) SNode{node = (Bin args)} type Nat = Word type Prefix = Word type Mask = Word type Key = Word type Size = Int data Path a = Root | LeftBin !Prefix !Mask (Path a) !(SNode a) | RightBin !Prefix !Mask !(SNode a) (Path a) data SNode a = SNode {sz :: !Size, node :: (Node a)} {-# ANN type SNode ForceSpecConstr #-} data Node a = Nil | Tip !Key a | Bin !Prefix !Mask !(SNode a) !(SNode a) instance Sized (SNode a) where getSize# SNode{sz} = unbox sz instance Sized a => Sized (Node a) where getSize# t = unbox (case t of Nil -> 0 Tip _ a -> getSize a Bin _ _ l r -> getSize l + getSize r) {-# INLINE sNode #-} sNode :: Sized a => Node a -> SNode a sNode !n = SNode (getSize n) n data WHole a = WHole !Key (Path a) {-# ANN type WHole ForceSpecConstr #-} $(noUnpackInstance ''Path) $(noUnpackInstance ''Node) $(unpackInstance ''WHole) $(unpackInstance ''SNode) {-# INLINE hole #-} hole :: Key -> Path a -> Hole Word a hole k path = Hole (WHole k path) #define HOLE(args) (Hole (WHole args)) instance Subset (TrieMap Word) where WordMap m1 <=? WordMap m2 = m1 <=? m2 instance Functor (TrieMap Word) where fmap f (WordMap m) = WordMap (f <$> m) instance Foldable (TrieMap Word) where foldMap f (WordMap m) = foldMap f m foldr f z (WordMap m) = foldr f z m foldl f z (WordMap m) = foldl f z m foldr1 f (WordMap m) = foldr1 f m foldl1 f (WordMap m) = foldl1 f m instance Traversable (TrieMap Word) where traverse f (WordMap m) = WordMap <$> traverse f m instance Buildable (TrieMap Word) Word where type UStack (TrieMap Word) = SNode {-# INLINE uFold #-} uFold = fmap WordMap . defaultUFold nil singleton (\ f k a -> insertWithC f k (getSize a) a) type AStack (TrieMap Word) = WordStack {-# INLINE aFold #-} aFold = fmap WordMap . fromAscList type DAStack (TrieMap Word) = WordStack {-# INLINE daFold #-} daFold = aFold const #define SETOP(op) op f (WordMap m1) (WordMap m2) = WordMap (op f m1 m2) instance SetOp (TrieMap Word) where SETOP(union) SETOP(isect) SETOP(diff) instance Project (TrieMap Word) where mapMaybe f (WordMap m) = WordMap $ mapMaybe f m mapEither f (WordMap m) = both WordMap (mapEither f) m -- | @'TrieMap' 'Word' a@ is based on "Data.IntMap". instance TrieKey Word where newtype TrieMap Word a = WordMap {getWordMap :: SNode a} newtype Hole Word a = Hole {getHole :: WHole a} emptyM = WordMap nil singletonM k a = WordMap (singleton k a) getSimpleM (WordMap (SNode _ n)) = case n of Nil -> Null Tip _ a -> Singleton a _ -> NonSimple sizeM (WordMap t) = getSize t lookupMC k (WordMap m) = lookupC k m singleHoleM k = hole k Root beforeM HOLE(_ path) = WordMap (before path) beforeWithM a HOLE(k path) = WordMap (beforeWith (singleton k a) path) afterM HOLE(_ path) = WordMap (after path) afterWithM a HOLE(k path) = WordMap (afterWith (singleton k a) path) {-# INLINE searchMC #-} searchMC !k (WordMap t) notfound found = searchC k t (unpack (notfound . Hole)) (\ a -> unpack (found a . Hole)) {-# INLINE indexM #-} indexM (WordMap m) i = index i m extractHoleM (WordMap m) = extractHole Root m where extractHole _ (SNode _ Nil) = mzero extractHole path TIP(kx x) = return (x, hole kx path) extractHole path BIN(p m l r) = extractHole (LeftBin p m path r) l `mplus` extractHole (RightBin p m l path) r {-# INLINE clearM #-} clearM HOLE(_ path) = case clear path of (# sz#, node #) -> WordMap SNode{sz = I# sz#, node} {-# INLINE assignM #-} assignM v HOLE(kx path) = case assign (singleton kx v) path of (# sz#, node #) -> WordMap SNode{sz = I# sz#, node} {-# INLINE unifyM #-} unifyM k1 a1 k2 a2 = WordMap <$> unify k1 a1 k2 a2 {-# INLINE unifierM #-} unifierM k' k a = Hole <$> unifier k' k a {-# INLINE insertWithM #-} insertWithM f k a (WordMap m) = WordMap (insertWithC f k (getSize a) a m) insertWithC :: Sized a => (a -> a) -> Key -> Int -> a -> SNode a -> SNode a insertWithC f !k !szA a !t = ins' t where {-# INLINE tip #-} tip = SNode {sz = szA, node = Tip k a} {-# INLINE out #-} out SNode{sz = I# sz#, node} = (# sz#, node #) {-# INLINE ins' #-} ins' t = case ins t of (# sz#, node #) -> SNode{sz = I# sz#, node} ins !t = case t of BIN(p m l r) | nomatch k p m -> out $ join k tip p t | mask0 k m -> out $ bin' p m (ins' l) r | otherwise -> out $ bin' p m l (ins' r) TIP(kx x) | k == kx -> out $ singleton kx (f x) | otherwise -> out $ join k tip kx t NIL -> out tip index :: Int# -> SNode a -> (# Int#, a, Hole Word a #) index i !t = indexT i t Root where indexT i TIP(kx x) path = (# i, x, hole kx path #) indexT i BIN(p m l r) path | i <# sl = indexT i l (LeftBin p m path r) | otherwise = indexT (i -# sl) r (RightBin p m l path) where !sl = getSize# l indexT _ NIL _ = indexFail () searchC :: Key -> SNode a -> (WHole a :~> r) -> (a -> WHole a :~> r) -> r searchC !k t notfound found = seek Root t where seek path t@BIN(p m l r) | nomatch k p m = notfound $~ WHole k (branchHole k p path t) | mask0 k m = seek (LeftBin p m path r) l | otherwise = seek (RightBin p m l path) r seek path t@TIP(ky y) | k == ky = found y $~ WHole k path | otherwise = notfound $~ WHole k (branchHole k ky path t) seek path NIL = notfound $~ WHole k path before, after :: Path a -> SNode a beforeWith, afterWith :: SNode a -> Path a -> SNode a before Root = nil before (LeftBin _ _ path _) = before path before (RightBin _ _ l path) = beforeWith l path beforeWith !t Root = t beforeWith !t (LeftBin _ _ path _) = beforeWith t path beforeWith !t (RightBin p m l path) = beforeWith (bin' p m l t) path after Root = nil after (RightBin _ _ _ path) = after path after (LeftBin _ _ path r) = afterWith r path afterWith !t Root = t afterWith !t (RightBin _ _ _ path) = afterWith t path afterWith !t (LeftBin p m path r) = afterWith (bin' p m t r) path clear :: Path a -> (# Int#, Node a #) assign :: SNode a -> Path a -> (# Int#, Node a #) clear Root = (# 0#, Nil #) clear (LeftBin _ _ path r) = assign r path clear (RightBin _ _ l path) = assign l path assign SNode{sz = I# sz#, node} Root = (# sz#, node #) assign !t (LeftBin p m path r) = assign (bin' p m t r) path assign !t (RightBin p m l path) = assign (bin' p m l t) path branchHole :: Key -> Prefix -> Path a -> SNode a -> Path a branchHole !k !p path t | mask0 k m = LeftBin p' m path t | otherwise = RightBin p' m t path where m = branchMask k p p' = mask k m {-# INLINE lookupC #-} lookupC :: Key -> SNode a -> Lookup r a lookupC !k !t = Lookup $ \ no yes -> let look BIN(_ m l r) = if zeroN k m then look l else look r look TIP(kx x) | k == kx = yes x look _ = no in look t singleton :: Sized a => Key -> a -> SNode a singleton k a = sNode (Tip k a) singletonMaybe :: Sized a => Key -> Maybe a -> SNode a singletonMaybe k = maybe nil (singleton k) instance Functor SNode where fmap f = map where map SNode{sz, node} = SNode sz $ case node of Nil -> Nil Tip k x -> Tip k (f x) Bin p m l r -> Bin p m (map l) (map r) instance Foldable SNode where foldMap f = fold where fold NIL = mempty fold TIP(_ x) = f x fold BIN(_ _ l r) = fold l `mappend` fold r foldr f = flip fold where fold BIN(_ _ l r) z = fold l (fold r z) fold TIP(_ x) z = f x z fold NIL z = z foldl f = fold where fold z BIN(_ _ l r) = fold (fold z l) r fold z TIP(_ x) = f z x fold z NIL = z instance Traversable SNode where traverse f = trav where trav NIL = pure nil trav SNode{sz, node = Tip kx x} = SNode sz . Tip kx <$> f x trav SNode{sz, node = Bin p m l r} = SNode sz .: Bin p m <$> trav l <*> trav r instance Subset SNode where (<=?) = subMap where t1@BIN(p1 m1 l1 r1) `subMap` BIN(p2 m2 l2 r2) | shorter m1 m2 = False | shorter m2 m1 = match p1 p2 m2 && (if mask0 p1 m2 then t1 `subMap` l2 else t1 `subMap` r2) | otherwise = (p1==p2) && l1 `subMap` l2 && r1 `subMap` r2 BIN({}) `subMap` _ = False TIP(k x) `subMap` t2 = runLookup (lookupC k t2) False (x n2 (_, Nil) -> n1 (Tip k x, _) -> alter (maybe (Just x) (f x)) k n2 (_, Tip k x) -> alter (maybe (Just x) (`f` x)) k n1 (Bin p1 m1 l1 r1, Bin p2 m2 l2 r2) | shorter m1 m2 -> union1 | shorter m2 m1 -> union2 | p1 == p2 -> bin p1 m1 (l1 \/ l2) (r1 \/ r2) | otherwise -> join p1 n1 p2 n2 where union1 | nomatch p2 p1 m1 = join p1 n1 p2 n2 | mask0 p2 m1 = bin p1 m1 (l1 \/ n2) r1 | otherwise = bin p1 m1 l1 (r1 \/ n2) union2 | nomatch p1 p2 m2 = join p1 n1 p2 n2 | mask0 p1 m2 = bin p2 m2 (n1 \/ l2) r2 | otherwise = bin p2 m2 l2 (n1 \/ r2) isect f = (/\) where n1@(SNode _ t1) /\ n2@(SNode _ t2) = case (t1, t2) of (Nil, _) -> nil (Tip{}, Nil) -> nil (Bin{}, Nil) -> nil (Tip k x, _) -> runLookup (lookupC k n2) nil (singletonMaybe k . f x) (_, Tip k y) -> runLookup (lookupC k n1) nil (singletonMaybe k . flip f y) (Bin p1 m1 l1 r1, Bin p2 m2 l2 r2) | shorter m1 m2 -> intersection1 | shorter m2 m1 -> intersection2 | p1 == p2 -> bin p1 m1 (l1 /\ l2) (r1 /\ r2) | otherwise -> nil where intersection1 | nomatch p2 p1 m1 = nil | mask0 p2 m1 = l1 /\ n2 | otherwise = r1 /\ n2 intersection2 | nomatch p1 p2 m2 = nil | mask0 p1 m2 = n1 /\ l2 | otherwise = n1 /\ r2 diff f = (\\) where n1@(SNode _ t1) \\ n2@(SNode _ t2) = case (t1, t2) of (Nil, _) -> nil (_, Nil) -> n1 (Tip k x, _) -> runLookup (lookupC k n2) n1 (singletonMaybe k . f x) (_, Tip k y) -> alter (>>= flip f y) k n1 (Bin p1 m1 l1 r1, Bin p2 m2 l2 r2) | shorter m1 m2 -> difference1 | shorter m2 m1 -> difference2 | p1 == p2 -> bin p1 m1 (l1 \\ l2) (r1 \\ r2) | otherwise -> n1 where difference1 | nomatch p2 p1 m1 = n1 | mask0 p2 m1 = bin p1 m1 (l1 \\ n2) r1 | otherwise = bin p1 m1 l1 (r1 \\ n2) difference2 | nomatch p1 p2 m2 = n1 | mask0 p1 m2 = n1 \\ l2 | otherwise = n1 \\ r2 instance Project SNode where mapMaybe f = mMaybe where mMaybe BIN(p m l r) = bin p m (mMaybe l) (mMaybe r) mMaybe TIP(kx x) = singletonMaybe kx (f x) mMaybe NIL = nil mapEither f = mEither where mEither BIN(p m l r) = (# bin p m l1 r1, bin p m l2 r2 #) where !(# l1, l2 #) = mEither l !(# r1, r2 #) = mEither r mEither TIP(kx x) = both (singletonMaybe kx) f x mEither NIL = (# nil, nil #) {-# INLINE alter #-} alter :: Sized a => (Maybe a -> Maybe a) -> Key -> SNode a -> SNode a alter f k t = getWordMap $ alterM f k (WordMap t) mask0 :: Key -> Mask -> Bool mask0 i m = i .&. m == 0 nomatch,match :: Key -> Prefix -> Mask -> Bool nomatch i p m = (mask i m) /= p match i p m = (mask i m) == p zeroN :: Nat -> Nat -> Bool zeroN i m = (i .&. m) == 0 mask :: Nat -> Nat -> Prefix mask i m = i .&. compl ((m-1) `xor` m) shorter :: Mask -> Mask -> Bool shorter m1 m2 = m1 > m2 branchMask :: Prefix -> Prefix -> Mask branchMask p1 p2 = highestBitMask (p1 `xor` p2) highestBitMask :: Nat -> Nat highestBitMask x0 = case (x0 .|. shiftR x0 1) of x1 -> case (x1 .|. shiftR x1 2) of x2 -> case (x2 .|. shiftR x2 4) of x3 -> case (x3 .|. shiftR x3 8) of x4 -> case (x4 .|. shiftR x4 16) of x5 -> case (x5 .|. shiftR x5 32) of -- for 64 bit platforms x6 -> (x6 `xor` (shiftR x6 1)) {-# INLINE join #-} join :: Prefix -> SNode a -> Prefix -> SNode a -> SNode a join p1 t1 p2 t2 | mask0 p1 m = SNode{sz = sz', node = Bin p m t1 t2} | otherwise = SNode{sz = sz', node = Bin p m t2 t1} where m = branchMask p1 p2 p = mask p1 m sz' = sz t1 + sz t2 nil :: SNode a nil = SNode 0 Nil bin :: Prefix -> Mask -> SNode a -> SNode a -> SNode a bin p m l@(SNode sl tl) r@(SNode sr tr) = case (tl, tr) of (Nil, _) -> r (_, Nil) -> l _ -> SNode (sl + sr) (Bin p m l r) bin' :: Prefix -> Mask -> SNode a -> SNode a -> SNode a bin' p m l@SNode{sz=sl} r@SNode{sz=sr} = assert (nonempty l && nonempty r) $ SNode (sl + sr) (Bin p m l r) where nonempty NIL = False nonempty _ = True {-# INLINE unify #-} unify :: Sized a => Key -> a -> Key -> a -> Lookup r (SNode a) unify k1 a1 k2 a2 = Lookup $ \ no yes -> if k1 == k2 then no else yes (join k1 (singleton k1 a1) k2 (singleton k2 a2)) {-# INLINE unifier #-} unifier :: Sized a => Key -> Key -> a -> Lookup r (WHole a) unifier k' k a = Lookup $ \ no yes -> if k == k' then no else yes (WHole k' $ branchHole k' k Root (singleton k a)) {-# INLINE fromAscList #-} fromAscList :: Sized a => (a -> a -> a) -> Foldl WordStack Key a (SNode a) fromAscList f = Foldl{zero = nil, ..} where begin kx vx = WordStack kx vx Nada snoc (WordStack kx vx stk) kz vz | kx == kz = WordStack kx (f vz vx) stk | otherwise = WordStack kz vz $ reduce (branchMask kx kz) kx (singleton kx vx) stk -- reduce :: Mask -> Prefix -> SNode a -> Stack a -> Stack a reduce !m !px !tx (Push py ty stk') | shorter m mxy = reduce m pxy (bin' pxy mxy ty tx) stk' where mxy = branchMask px py; pxy = mask px mxy reduce _ px tx stk = Push px tx stk done (WordStack kx vx stk) = case finish kx (singleton kx vx) stk of (# sz#, node #) -> SNode {sz = I# sz#, node} finish !px !tx (Push py ty stk) = finish p (join' py ty px tx) stk where m = branchMask px py; p = mask px m finish _ SNode{sz, node} Nada = (# unbox sz, node #) join' p1 t1 p2 t2 = SNode{sz = sz t1 + sz t2, node = Bin p m t1 t2} where m = branchMask p1 p2 p = mask p1 m data WordStack a = WordStack !Key a (Stack a) data Stack a = Push !Prefix !(SNode a) !(Stack a) | Nada