{-# LANGUAGE BangPatterns, UnboxedTuples, TypeFamilies, PatternGuards, MagicHash, CPP, TupleSections, NamedFieldPuns, FlexibleInstances #-} {-# LANGUAGE RecordWildCards, ImplicitParams, GeneralizedNewtypeDeriving, StandaloneDeriving, MultiParamTypeClasses #-} {-# OPTIONS -funbox-strict-fields #-} module Data.TrieMap.OrdMap () where import Control.Monad.Lookup import Data.TrieMap.TrieKey import Data.TrieMap.Sized import Data.TrieMap.Modifiers import Prelude hiding (lookup, foldr, foldl, foldr1, foldl1, map) import GHC.Exts #define DELTA 5 #define RATIO 2 data Path k a = Root | LeftBin k a !(Path k a) !(SNode k a) | RightBin k a !(SNode k a) !(Path k a) data Node k a = Tip | Bin k a !(SNode k a) !(SNode k a) data SNode k a = SNode{sz :: !Int, count :: !Int, node :: Node k a} #define TIP SNode{node=Tip} #define BIN(args) SNode{node=Bin args} -- Morally reprehensible exploitation of generalized newtype deriving. class ImmoralCast a b where immoralCast :: SNode k a -> SNode k b instance ImmoralCast a a where immoralCast = id deriving instance ImmoralCast a (Elem a) instance Sized a => Sized (Node k a) where getSize# m = unbox (case m of Tip -> 0 Bin _ a l r -> getSize a + getSize l + getSize r) instance Sized (SNode k a) where getSize# SNode{sz} = unbox sz nCount :: Node k a -> Int nCount Tip = 0 nCount (Bin _ _ l r) = 1 + count l + count r sNode :: Sized a => Node k a -> SNode k a sNode !n = SNode (getSize n) (nCount n) n tip :: SNode k a tip = SNode 0 0 Tip instance Ord k => Subset (TrieMap (Ordered k)) where OrdMap m1 <=? OrdMap m2 = m1 <=? m2 instance Functor (TrieMap (Ordered k)) where fmap f (OrdMap m) = OrdMap (f <$> m) instance Foldable (TrieMap (Ordered k)) where foldMap f (OrdMap m) = foldMap f m foldr f z (OrdMap m) = foldr f z m foldl f z (OrdMap m) = foldl f z m instance Traversable (TrieMap (Ordered k)) where traverse f (OrdMap m) = OrdMap <$> traverse f m instance Ord k => Buildable (TrieMap (Ordered k)) (Ordered k) where type UStack (TrieMap (Ordered k)) = TrieMap (Ordered k) uFold = defaultUFold emptyM singletonM insertWithM type AStack (TrieMap (Ordered k)) = Distinct (Ordered k) (Stack k) aFold = combineFold daFold type DAStack (TrieMap (Ordered k)) = Stack k daFold = OrdMap <$> mapFoldlKeys unOrd fromDistAscList #define SETOP(op) op f (OrdMap m1) (OrdMap m2) = OrdMap (op f m1 m2) instance Ord k => SetOp (TrieMap (Ordered k)) where SETOP(union) SETOP(isect) SETOP(diff) instance Ord k => Project (TrieMap (Ordered k)) where mapMaybe f (OrdMap m) = OrdMap $ mapMaybe f m mapEither f (OrdMap m) = both OrdMap (mapEither f) m -- | @'TrieMap' ('Ordered' k) a@ is based on "Data.Map". instance Ord k => TrieKey (Ordered k) where newtype TrieMap (Ordered k) a = OrdMap (SNode k a) data Hole (Ordered k) a = Empty k !(Path k a) | Full k !(Path k a) !(SNode k a) !(SNode k a) emptyM = OrdMap tip singletonM (Ord k) a = OrdMap (singleton k a) lookupMC (Ord k) (OrdMap m) = lookupC k m getSimpleM (OrdMap m) = case m of TIP -> Null BIN(_ a TIP TIP) -> Singleton a _ -> NonSimple sizeM (OrdMap m) = sz m singleHoleM (Ord k) = Empty k Root beforeM (Empty _ path) = OrdMap $ before tip path beforeM (Full _ path l _) = OrdMap $ before l path beforeWithM a (Empty k path) = OrdMap $ before (singleton k a) path beforeWithM a (Full k path l _) = OrdMap $ before (insertMax k a l) path afterM (Empty _ path) = OrdMap $ after tip path afterM (Full _ path _ r) = OrdMap $ after r path afterWithM a (Empty k path) = OrdMap $ after (singleton k a) path afterWithM a (Full k path _ r) = OrdMap $ after (insertMin k a r) path searchMC (Ord k) (OrdMap m) = search k m indexM (OrdMap m) i = indexT Root i m where indexT path !i SNode{sz, node = Bin kx x l r} | i <# sl = indexT (LeftBin kx x path r) i l | i <# sx = (# i -# sl, x, Full kx path l r #) | otherwise = indexT (RightBin kx x l path) (i -# sx) r where !sl = getSize# l !sx = unbox $ sz - getSize r indexT _ _ _ = indexFail () extractHoleM (OrdMap m) = extractHole Root m where extractHole path BIN(kx x l r) = extractHole (LeftBin kx x path r) l `mplus` return (x, Full kx path l r) `mplus` extractHole (RightBin kx x l path) r extractHole _ _ = mzero clearM (Empty _ path) = OrdMap $ rebuild tip path clearM (Full _ path l r) = OrdMap $ rebuild (merge l r) path assignM x (Empty k path) = OrdMap $ rebuild (singleton k x) path assignM x (Full k path l r) = OrdMap $ rebuild (join k x l r) path unifierM (Ord k') (Ord k) a = case compare k' k of EQ -> mzero LT -> return $ Empty k' (LeftBin k a Root tip) GT -> return $ Empty k' (RightBin k a tip Root) unifyM (Ord k1) a1 (Ord k2) a2 = case compare k1 k2 of EQ -> mzero LT -> return $ OrdMap $ bin k1 a1 tip (singleton k2 a2) GT -> return $ OrdMap $ bin k1 a1 (singleton k2 a2) tip {-# INLINE insertWithM #-} insertWithM f (Ord k) a (OrdMap m) = OrdMap (insertWith f k a m) insertWith :: (Ord k, Sized a) => (a -> a) -> k -> a -> SNode k a -> SNode k a insertWith f k a = k `seq` ins where ins BIN(kx x l r) = case compare k kx of EQ -> bin kx (f x) l r LT -> balance kx x (ins l) r GT -> balance kx x l (ins r) ins TIP = singleton k a rebuild :: Sized a => SNode k a -> Path k a -> SNode k a rebuild t Root = t rebuild t (LeftBin kx x path r) = rebuild (balance kx x t r) path rebuild t (RightBin kx x l path) = rebuild (balance kx x l t) path lookupC :: Ord k => k -> SNode k a -> Lookup r a lookupC k = look where look BIN(kx x l r) = case compare k kx of LT -> look l EQ -> return x GT -> look r look _ = mzero singleton :: Sized a => k -> a -> SNode k a singleton k a = bin k a tip tip instance Traversable (SNode k) where traverse f = trav where trav TIP = pure tip trav SNode{node = Bin k a l r, ..} = let done a' l' r' = SNode sz count (Bin k a' l' r') in done <$> f a <*> trav l <*> trav r instance Foldable (SNode k) where foldMap _ TIP = mempty foldMap f BIN(_ a l r) = foldMap f l `mappend` f a `mappend` foldMap f r foldr _ z TIP = z foldr f z BIN(_ a l r) = foldr f (a `f` foldr f z r) l foldl _ z TIP = z foldl f z BIN(_ a l r) = foldl f (foldl f z l `f` a) r instance Functor (SNode k) where fmap f = map where map SNode{node = Bin k a l r, ..} = SNode {node = Bin k (f a) (map l) (map r), ..} map _ = tip instance Ord k => Project (SNode k) where mapMaybe f = mMaybe where mMaybe BIN(k a l r) = joinMaybe k (f a) (mMaybe l) (mMaybe r) mMaybe _ = tip mapEither f = mEither where mEither BIN(k a l r) = (# joinMaybe k aL lL rL, joinMaybe k aR lR rR #) where !(# aL, aR #) = f a !(# lL, lR #) = mEither l !(# rL, rR #) = mEither r mEither _ = (# tip, tip #) splitLookup :: Ord k => k -> SNode k (Elem a) -> (SNode k (Elem a) -> Maybe (Elem a) -> SNode k (Elem a) -> r) -> r splitLookup k t cont = search k t (split Nothing) (split . Just) where split v (Empty _ path) = cont (before tip path) v (after tip path) split v (Full _ path l r) = cont (before l path) v (after r path) instance Ord k => Subset (SNode k) where t1 <=? t2 = immoralCast t1 `subMap` immoralCast t2 where TIP `subMap` _ = True _ `subMap` TIP = False BIN(kx x l r) `subMap` t = splitLookup kx t result where result _ Nothing _ = False result tl (Just y) tr = x <=? y && l `subMap` tl && r `subMap` tr fromDistAscList :: (Eq k, Sized a) => Foldl (Stack k) k a (SNode k a) fromDistAscList = Foldl{zero = tip, ..} where incr !t (Yes t' stk) = No (incr (t' `glue` t) stk) incr !t (No stk) = Yes t stk incr !t End = Yes t End begin k a = Yes (singleton k a) End snoc stk k a = incr (singleton k a) stk roll !t End = t roll !t (No stk) = roll t stk roll !t (Yes t' stk) = roll (t' `glue` t) stk done = roll tip data Stack k a = No (Stack k a) | Yes !(SNode k a) (Stack k a) | End instance Ord k => SetOp (SNode k) where union f = hedgeUnion f (const LT) (const GT) diff f = hedgeDiff f (const LT) (const GT) isect f m1 m2 = immoralCast m1 `intersection` m2 where t1@BIN(_ _ _ _) `intersection` BIN(k2 x2 l2 r2) = splitLookup k2 t1 result where result tl found tr = joinMaybe k2 (found >>= \ (Elem x1') -> f x1' x2) (tl `intersection` l2) (tr `intersection` r2) _ `intersection` _ = tip hedgeUnion :: (Ord k, Sized a) => (a -> a -> Maybe a) -> (k -> Ordering) -> (k -> Ordering) -> SNode k a -> SNode k a -> SNode k a hedgeUnion _ _ _ t1 TIP = t1 hedgeUnion _ cmplo cmphi TIP BIN(kx x l r) = join kx x (filterGt cmplo l) (filterLt cmphi r) hedgeUnion f cmplo cmphi BIN(kx x l r) t2 = joinMaybe kx newx (hedgeUnion f cmplo cmpkx l lt) (hedgeUnion f cmpkx cmphi r gt) where cmpkx k = compare kx k lt = trim cmplo cmpkx t2 (found,gt) = trimLookupLo kx cmphi t2 newx = case found of Nothing -> Just x Just (_,y) -> f x y filterGt :: (Ord k, Sized a) => (k -> Ordering) -> SNode k a -> SNode k a filterGt _ TIP = tip filterGt cmp BIN(kx x l r) = case cmp kx of LT -> join kx x (filterGt cmp l) r GT -> filterGt cmp r EQ -> r filterLt :: (Ord k, Sized a) => (k -> Ordering) -> SNode k a -> SNode k a filterLt _ TIP = tip filterLt cmp BIN(kx x l r) = case cmp kx of LT -> filterLt cmp l GT -> join kx x l (filterLt cmp r) EQ -> l trim :: (k -> Ordering) -> (k -> Ordering) -> SNode k a -> SNode k a trim cmplo cmphi = trimmer where trimmer TIP = tip trimmer t@BIN(kx _ l r) = case (cmplo kx, cmphi kx) of (LT, GT) -> t (LT, _) -> trimmer l _ -> trimmer r trimLookupLo :: Ord k => k -> (k -> Ordering) -> SNode k a -> (Maybe (k,a), SNode k a) trimLookupLo _ _ TIP = (Nothing,tip) trimLookupLo lo cmphi t@BIN(kx x l r) = case compare lo kx of LT -> case cmphi kx of GT -> (runLookup (lookupC lo t) Nothing (\ a -> Just (lo, a)), t) _ -> trimLookupLo lo cmphi l GT -> trimLookupLo lo cmphi r EQ -> (Just (kx,x),trim (compare lo) cmphi r) hedgeDiff :: (Ord k, Sized a) => (a -> b -> Maybe a) -> (k -> Ordering) -> (k -> Ordering) -> SNode k a -> SNode k b -> SNode k a hedgeDiff _ _ _ TIP _ = tip hedgeDiff _ cmplo cmphi BIN(kx x l r) TIP = join kx x (filterGt cmplo l) (filterLt cmphi r) hedgeDiff f cmplo cmphi t BIN(kx x l r) = case found of Nothing -> merge tl tr Just (ky,y) -> case f y x of Nothing -> merge tl tr Just z -> join ky z tl tr where cmpkx k = compare kx k lt = trim cmplo cmpkx t (found,gt) = trimLookupLo kx cmphi t tl = hedgeDiff f cmplo cmpkx lt l tr = hedgeDiff f cmpkx cmphi gt r joinMaybe :: (Ord k, Sized a) => k -> Maybe a -> SNode k a -> SNode k a -> SNode k a joinMaybe kx = maybe merge (join kx) join :: Sized a => k -> a -> SNode k a -> SNode k a -> SNode k a join kx x TIP r = insertMin kx x r join kx x l TIP = insertMax kx x l join kx x l@(SNode _ sL (Bin ky y ly ry)) r@(SNode _ sR (Bin kz z lz rz)) | DELTA * sL <= sR = balance kz z (join kx x l lz) rz | DELTA * sR <= sL = balance ky y ly (join kx x ry r) | otherwise = bin kx x l r -- insertMin and insertMax don't perform potentially expensive comparisons. insertMax,insertMin :: Sized a => k -> a -> SNode k a -> SNode k a insertMax kx x = insMax where insMax TIP = singleton kx x insMax BIN(ky y l r) = balance ky y l (insMax r) insertMin kx x = insMin where insMin TIP = singleton kx x insMin BIN(ky y l r) = balance ky y (insMin l) r {-------------------------------------------------------------------- [merge l r]: merges two trees. --------------------------------------------------------------------} merge :: Sized a => SNode k a -> SNode k a -> SNode k a merge TIP r = r merge l TIP = l merge l@(SNode _ sL (Bin kx x lx rx)) r@(SNode _ sR (Bin ky y ly ry)) | DELTA * sL <= sR = balance ky y (merge l ly) ry | DELTA * sR <= sL = balance kx x lx (merge rx r) | otherwise = glue l r {-------------------------------------------------------------------- [glue l r]: glues two trees together. Assumes that [l] and [r] are already balanced with respect to each other. --------------------------------------------------------------------} glue :: Sized a => SNode k a -> SNode k a -> SNode k a glue TIP r = r glue l TIP = l glue l r | count l > count r = let !(# f, l' #) = deleteFindMax balance l in f l' r | otherwise = let !(# f, r' #) = deleteFindMin balance r in f l r' deleteFindMin :: Sized a => (k -> a -> x) -> SNode k a -> (# x, SNode k a #) deleteFindMin f t = case t of BIN(k x TIP r) -> (# f k x, r #) BIN(k x l r) -> onSnd (\ l' -> balance k x l' r) (deleteFindMin f) l _ -> (# error "Map.deleteFindMin: can not return the minimal element of an empty fmap", tip #) deleteFindMax :: Sized a => (k -> a -> x) -> SNode k a -> (# x, SNode k a #) deleteFindMax f t = case t of BIN(k x l TIP) -> (# f k x, l #) BIN(k x l r) -> onSnd (balance k x l) (deleteFindMax f) r TIP -> (# error "Map.deleteFindMax: can not return the maximal element of an empty fmap", tip #) balance :: Sized a => k -> a -> SNode k a -> SNode k a -> SNode k a balance k x l r | sR >= (DELTA * sL) = rotateL k x l r | sL >= (DELTA * sR) = rotateR k x l r | otherwise = bin k x l r where !sL = count l !sR = count r -- rotate rotateL :: Sized a => k -> a -> SNode k a -> SNode k a -> SNode k a rotateL k x l r@BIN(_ _ ly ry) | sL < (RATIO * sR) = singleL k x l r | otherwise = doubleL k x l r where !sL = count ly !sR = count ry rotateL k x l TIP = insertMax k x l rotateR :: Sized a => k -> a -> SNode k a -> SNode k a -> SNode k a rotateR k x l@BIN(_ _ ly ry) r | sR < (RATIO * sL) = singleR k x l r | otherwise = doubleR k x l r where !sL = count ly !sR = count ry rotateR k x TIP r = insertMin k x r -- basic rotations singleL, singleR :: Sized a => k -> a -> SNode k a -> SNode k a -> SNode k a singleL k1 x1 t1 BIN(k2 x2 t2 t3) = bin k2 x2 (bin k1 x1 t1 t2) t3 singleL k1 x1 t1 TIP = bin k1 x1 t1 tip singleR k1 x1 BIN(k2 x2 t1 t2) t3 = bin k2 x2 t1 (bin k1 x1 t2 t3) singleR k1 x1 TIP t2 = bin k1 x1 tip t2 doubleL, doubleR :: Sized a => k -> a -> SNode k a -> SNode k a -> SNode k a doubleL k1 x1 t1 BIN(k2 x2 BIN(k3 x3 t2 t3) t4) = bin k3 x3 (bin k1 x1 t1 t2) (bin k2 x2 t3 t4) doubleL k1 x1 t1 t2 = singleL k1 x1 t1 t2 doubleR k1 x1 BIN(k2 x2 t1 BIN(k3 x3 t2 t3)) t4 = bin k3 x3 (bin k2 x2 t1 t2) (bin k1 x1 t3 t4) doubleR k1 x1 t1 t2 = singleR k1 x1 t1 t2 bin :: Sized a => k -> a -> SNode k a -> SNode k a -> SNode k a bin k x l r = sNode (Bin k x l r) before :: Sized a => SNode k a -> Path k a -> SNode k a before t (LeftBin _ _ path _) = before t path before t (RightBin k a l path) = before (join k a l t) path before t _ = t after :: Sized a => SNode k a -> Path k a -> SNode k a after t (LeftBin k a path r) = after (join k a t r) path after t (RightBin _ _ _ path) = after t path after t _ = t search :: Ord k => k -> SNode k a -> SearchCont (Hole (Ordered k) a) a r search k t f g = searcher Root t where searcher path TIP = f (Empty k path) searcher path BIN(kx x l r) = case compare k kx of LT -> searcher (LeftBin kx x path r) l EQ -> g x (Full k path l r) GT -> searcher (RightBin kx x l path) r