{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wall -Werror -fno-warn-unused-imports #-}

module BTree.Linear
  ( BTree
  , Context(..)
  , lookup
  , insert
  , modifyWithM
  , new
  , foldrWithKey
  , toAscList
  , fromList
  , debugMap
  ) where

import Prelude hiding (lookup)

import Data.Primitive.MutVar
import Control.Monad
import Data.Foldable (foldlM)
import Data.Primitive (MutableArray,Prim)
import qualified Data.Primitive as P
import Data.Primitive.PrimArray
import Control.Monad.Primitive

-- | Tree-wide settings passed to every operation.
data Context s = Context
  { contextDegree :: {-# UNPACK #-} !Int
    -- ^ Maximum number of children of an internal node. Key arrays are
    -- allocated with room for @contextDegree - 1@ entries.
  }

-- | A node of the tree: its key count, its key array, and either the
-- values (leaf) or the child nodes (internal node).
data BTree s k v = BTree
  !(MutVar s Int) -- current number of keys in this node
  !(MutablePrimArray s k)
  !(Contents s k v)

-- | Payload of a node: values in a leaf, children in an internal node.
data Contents s k v
  = ContentsValues !(MutablePrimArray s v)
  | ContentsNodes !(MutableArray s (BTree s k v))

-- | Allocate an empty tree, consisting of a single empty leaf whose key and
-- value arrays have room for @degree - 1@ entries. Calls 'error' when the
-- degree is below 3.
new :: (PrimMonad m, Prim k, Prim v)
  => Context (PrimState m) -- ^ Max number of children per node
  -> m (BTree (PrimState m) k v)
new (Context degree) = do
    when (degree < 3) $
      error "Btree.new: max nodes per child cannot be less than 3"
    -- Allocation order: size ref, key array, value array.
    BTree
      <$> newMutVar 0
      <*> newPrimArray (degree - 1)
      <*> (ContentsValues <$> newPrimArray (degree - 1))

{-# INLINABLE lookup #-}
-- | Find the value associated with a key, returning 'Nothing' when the key
-- is not stored in the tree.
lookup :: forall m k v. (PrimMonad m, Ord k, Prim k, Prim v)
  => Context (PrimState m)
  -> BTree (PrimState m) k v
  -> k
  -> m (Maybe v)
lookup (Context _) theNode k = descend theNode
  where
    descend :: BTree (PrimState m) k v -> m (Maybe v)
    descend (BTree szRef keys contents) = do
      n <- readMutVar szRef
      case contents of
        -- Internal node: pick the child by the first key greater than k,
        -- then continue searching inside it.
        ContentsNodes children ->
          findIndexBetween keys k n >>= P.readArray children >>= descend
        -- Leaf: an exact match yields the value at the same index.
        ContentsValues values ->
          findIndex keys k n >>= either
            (const (return Nothing))
            (fmap Just . readPrimArray values)

-- | Outcome of modifying a subtree: either it absorbed the change ('Ok',
-- carrying the value now stored for the key) or it had to split.
data Insert s k v
  = Ok !v
  | Split !(BTree s k v) !k !v
  -- ^ The new node that will go to the right,
  -- the key propagated to the parent,
  -- the inserted value.
-- | Filler for slots of freshly allocated child arrays that have not been
-- written yet. Forcing it means the node bookkeeping in this module is
-- wrong.
uninitializedNode :: a
uninitializedNode = error "uninitializedNode: this should not be forced, b+ tree implementation has a mistake."

{-# INLINE insert #-}
-- | Insert a key-value pair, replacing the value if the key is already
-- present. The returned tree is the root to use from now on: when the old
-- root has to be split, a new root node is allocated and returned instead.
insert :: (PrimMonad m, Ord k, Prim k, Prim v)
  => Context (PrimState m)
  -> BTree (PrimState m) k v
  -> k
  -> v
  -> m (BTree (PrimState m) k v)
insert ctx m k v = do
    (_,node) <- modifyWithM ctx m k (\_ -> return v)
    return node

-- | All key-value pairs in ascending key order. This is provided for
-- completeness but is not something typically useful in production code.
toAscList :: forall m k v. (PrimMonad m, Ord k, Prim k, Prim v)
  => Context (PrimState m)
  -> BTree (PrimState m) k v
  -> m [(k,v)]
toAscList = foldrWithKey f []
  where
    f :: k -> v -> [(k,v)] -> m [(k,v)]
    f k v xs = return ((k,v) : xs)

-- | Build a tree by inserting the pairs from left to right into a fresh
-- empty tree. For duplicate keys the last pair wins.
fromList :: (PrimMonad m, Ord k, Prim k, Prim v)
  => Context (PrimState m)
  -> [(k,v)]
  -> m (BTree (PrimState m) k v)
fromList ctx xs = do
    root0 <- new ctx
    foldlM (\root (k,v) -> insert ctx root k v) root0 xs

-- | Monadic right fold over all key-value pairs. Within every node the
-- entries are visited from the highest index down to 0, so the step
-- function sees the greatest key first.
foldrWithKey :: forall m k v b. (PrimMonad m, Ord k, Prim k, Prim v)
  => (k -> v -> b -> m b)
  -> b
  -> Context (PrimState m)
  -> BTree (PrimState m) k v
  -> m b
foldrWithKey f b0 (Context _) root = flip go b0 root
  where
    go :: BTree (PrimState m) k v -> b -> m b
    go (BTree szRef keys c) b = do
      sz <- readMutVar szRef
      case c of
        ContentsValues values -> foldrPrimArrayPairs sz f b keys values
        -- An internal node with sz keys has sz + 1 children.
        ContentsNodes nodes -> foldrArray (sz + 1) go b nodes

-- | Monadic right fold over the first @len@ elements of a boxed mutable
-- array, from index @len - 1@ down to 0. The accumulator is forced to WHNF
-- at every step.
foldrArray :: forall m a b. (PrimMonad m)
  => Int -- ^ length of array
  -> (a -> b -> m b)
  -> b
  -> MutableArray (PrimState m) a
  -> m b
foldrArray len f b0 arr = go (len - 1) b0
  where
    go :: Int -> b -> m b
    go !ix !b1 =
      if ix >= 0
        then do
          a <- P.readArray arr ix
          b2 <- f a b1
          go (ix - 1) b2
        else return b1

-- | Monadic right fold over the first @len@ index-aligned key/value pairs
-- of two primitive arrays, from index @len - 1@ down to 0. The accumulator
-- is forced to WHNF at every step.
foldrPrimArrayPairs :: forall m k v b. (PrimMonad m, Ord k, Prim k, Prim v)
  => Int -- ^ length of arrays
  -> (k -> v -> b -> m b)
  -> b
  -> MutablePrimArray (PrimState m) k
  -> MutablePrimArray (PrimState m) v
  -> m b
foldrPrimArrayPairs len f b0 ks vs = go (len - 1) b0
  where
    go :: Int -> b -> m b
    go !ix !b1 =
      if ix >= 0
        then do
          k <- readPrimArray ks ix
          v <- readPrimArray vs ix
          b2 <- f k v b1
          go (ix - 1) b2
        else return b1

{-# SPECIALIZE modifyWithM :: Context RealWorld -> BTree RealWorld Int Int -> Int -> (Maybe Int -> IO Int) -> IO (Int, BTree RealWorld Int Int) #-}
{-# INLINABLE modifyWithM #-}
-- | Set the value for key @k@ to the result of @alter@. @alter@ receives
-- @Just@ the current value when the key is present and 'Nothing'
-- otherwise; in the latter case the key is inserted, splitting full nodes
-- on the way back up.
--
-- Returns the value now stored for the key together with the root to use
-- from now on. This is the original root unless the root itself was
-- split, in which case a new root with a single key is allocated.
modifyWithM :: forall m s k v. (PrimMonad m, Ord k, Prim k, Prim v)
  => Context s
  -> BTree (PrimState m) k v
  -> k
  -> (Maybe v -> m v)
  -> m (v, BTree (PrimState m) k v)
modifyWithM (Context degree) root k alter = do
    ins <- go root
    case ins of
      Ok v -> return (v,root)
      Split rightNode newRootKey v -> do
        -- The root split: grow the tree by one level.
        let leftNode = root
        newRootSz <- newMutVar 1
        newRootKeys <- newPrimArray (degree - 1)
        writePrimArray newRootKeys 0 newRootKey
        newRootChildren <- P.newArray degree uninitializedNode
        P.writeArray newRootChildren 0 leftNode
        P.writeArray newRootChildren 1 rightNode
        let newRoot = BTree newRootSz newRootKeys (ContentsNodes newRootChildren)
        return (v,newRoot)
  where
    go :: BTree (PrimState m) k v -> m (Insert (PrimState m) k v)
    go (BTree szRef keys c) = do
      sz <- readMutVar szRef
      case c of
        ContentsValues values -> do
          e <- findIndex keys k sz
          case e of
            Left gtIx -> do
              v <- alter Nothing
              if sz < degree - 1
                then do
                  -- We have enough space
                  writeMutVar szRef (sz + 1)
                  unsafeInsertPrimArray sz gtIx k keys
                  unsafeInsertPrimArray sz gtIx v values
                  return (Ok v)
                else do
                  -- We do not have enough space. The node must be split.
                  -- The first key of the new right leaf is propagated to
                  -- the parent as the separator.
                  let leftSize = div sz 2
                      rightSize = sz - leftSize
                      leftKeys = keys
                      leftValues = values
                  if gtIx < leftSize
                    then do
                      rightKeys <- newPrimArray (degree - 1)
                      rightValues <- newPrimArray (degree - 1)
                      rightSzRef <- newMutVar rightSize
                      copyMutablePrimArray rightKeys 0 leftKeys leftSize rightSize
                      copyMutablePrimArray rightValues 0 leftValues leftSize rightSize
                      unsafeInsertPrimArray leftSize gtIx k leftKeys
                      unsafeInsertPrimArray leftSize gtIx v leftValues
                      propagated <- readPrimArray rightKeys 0
                      writeMutVar szRef (leftSize + 1)
                      return (Split (BTree rightSzRef rightKeys (ContentsValues rightValues)) propagated v)
                    else do
                      rightKeys <- newPrimArray (degree - 1)
                      rightValues <- newPrimArray (degree - 1)
                      rightSzRef <- newMutVar (rightSize + 1)
                      -- Currently, we're copying from left to right and
                      -- then doing another copy from right to right. We
                      -- might be able to do better. We could do the same number
                      -- of memcpys but copy fewer total elements and not
                      -- have the slowdown caused by overlap.
                      copyMutablePrimArray rightKeys 0 leftKeys leftSize rightSize
                      copyMutablePrimArray rightValues 0 leftValues leftSize rightSize
                      unsafeInsertPrimArray rightSize (gtIx - leftSize) k rightKeys
                      unsafeInsertPrimArray rightSize (gtIx - leftSize) v rightValues
                      propagated <- readPrimArray rightKeys 0
                      writeMutVar szRef leftSize
                      return (Split (BTree rightSzRef rightKeys (ContentsValues rightValues)) propagated v)
            Right ix -> do
              v <- readPrimArray values ix
              v' <- alter (Just v)
              writePrimArray values ix v'
              return (Ok v')
        ContentsNodes nodes -> do
          (gtIx,isEq) <- findIndexGte keys k sz
          -- The child we descend into. A key equal to a separator lives in
          -- the subtree to the right of that separator.
          let childIx = if isEq then gtIx + 1 else gtIx
          node <- P.readArray nodes childIx
          ins <- go node
          case ins of
            Ok v -> return (Ok v)
            Split rightNode propagated v ->
              -- The child at childIx split: its new right sibling belongs
              -- at child position childIx + 1, and the separator between
              -- the two at key position childIx. (Using gtIx here would be
              -- wrong whenever isEq held.)
              if sz < degree - 1
                then do
                  unsafeInsertPrimArray sz childIx propagated keys
                  unsafeInsertArray (sz + 1) (childIx + 1) rightNode nodes
                  writeMutVar szRef (sz + 1)
                  return (Ok v)
                else do
                  -- The node is full as well: move the key at middleIx up
                  -- to the parent and split the remaining keys/children.
                  -- NOTE(review): middleIx is chosen before accounting for
                  -- where the new key lands. When it lands on the left, the
                  -- left node keeps middleIx + 1 keys and the right node only
                  -- sz - middleIx - 1; with degree 3 the right node ends up
                  -- with zero keys and a single child. Lookups still work,
                  -- but the split is not balanced.
                  let middleIx = div sz 2
                      leftKeys = keys
                      leftNodes = nodes
                  middleKey <- readPrimArray keys middleIx
                  rightKeys :: MutablePrimArray (PrimState m) k <- newPrimArray (degree - 1)
                  rightNodes <- P.newArray degree uninitializedNode
                  rightSzRef <- newMutVar 0 -- this always gets replaced
                  let leftSize = middleIx
                      rightSize = sz - leftSize
                  if middleIx >= childIx
                    then do
                      copyMutablePrimArray rightKeys 0 leftKeys (leftSize + 1) (rightSize - 1)
                      P.copyMutableArray rightNodes 0 leftNodes (leftSize + 1) rightSize
                      unsafeInsertPrimArray leftSize childIx propagated leftKeys
                      unsafeInsertArray (leftSize + 1) (childIx + 1) rightNode leftNodes
                      writeMutVar szRef (leftSize + 1)
                      writeMutVar rightSzRef (rightSize - 1)
                    else do
                      -- Currently, we're copying from left to right and
                      -- then doing another copy from right to right. We can do better.
                      -- There is a similar note further up.
                      copyMutablePrimArray rightKeys 0 leftKeys (leftSize + 1) (rightSize - 1)
                      P.copyMutableArray rightNodes 0 leftNodes (leftSize + 1) rightSize
                      unsafeInsertPrimArray (rightSize - 1) (childIx - leftSize - 1) propagated rightKeys
                      unsafeInsertArray rightSize (childIx - leftSize) rightNode rightNodes
                      writeMutVar szRef leftSize
                      writeMutVar rightSzRef rightSize
                  return (Split (BTree rightSzRef rightKeys (ContentsNodes rightNodes)) middleKey v)

-- | Index of the first of the first @sz@ elements of @marr@ that is
-- strictly greater than @needle@, or @sz@ when there is none.
--
-- Preconditions:
--
-- * @marr@ is sorted low to high
-- * @sz@ is less than or equal to the true size of @marr@
--
-- The returned value is in the inclusive range @[0, sz]@.
findIndexBetween :: forall m a. (PrimMonad m, Ord a, Prim a)
  => MutablePrimArray (PrimState m) a
  -> a
  -> Int
  -> m Int
findIndexBetween !marr !needle !sz = go 0
  where
    go :: Int -> m Int
    go !i =
      if i < sz
        then do
          a <- readPrimArray marr i
          if a > needle
            then return i
            else go (i + 1)
        else return i -- i should be equal to sz

-- | Linear search for @needle@ among the first @sz@ elements of @marr@.
--
-- Preconditions:
--
-- * @marr@ is sorted low to high
-- * @sz@ is less than or equal to the true size of @marr@
--
-- Returns @Right i@ when element @i@ equals @needle@ (so @i@ is in
-- @[0, sz - 1]@). Otherwise returns @Left i@, where @i@ is in the inclusive
-- range @[0, sz]@ and is the index of the first element greater than
-- @needle@, i.e. the position at which @needle@ would have to be inserted.
findIndex :: forall m a. (PrimMonad m, Ord a, Prim a)
  => MutablePrimArray (PrimState m) a
  -> a
  -> Int
  -> m (Either Int Int)
findIndex !marr !needle !sz = go 0
  where
    go :: Int -> m (Either Int Int)
    go !i =
      if i < sz
        then do
          a <- readPrimArray marr i
          case compare a needle of
            LT -> go (i + 1)
            EQ -> return (Right i)
            GT -> return (Left i)
        else return (Left i)

-- | Index of the first of the first @sz@ elements of @marr@ that is
-- greater than or equal to @needle@ (or @sz@ when there is none). The
-- second value in the tuple is true when the index match was exact.
findIndexGte :: forall m a. (PrimMonad m, Ord a, Prim a)
  => MutablePrimArray (PrimState m) a
  -> a
  -> Int
  -> m (Int,Bool)
findIndexGte !marr !needle !sz = go 0
  where
    go :: Int -> m (Int,Bool)
    go !i =
      if i < sz
        then do
          a <- readPrimArray marr i
          case compare a needle of
            LT -> go (i + 1)
            EQ -> return (i,True)
            GT -> return (i,False)
        else return (i,False)

-- | Insert an element in the array, shifting the values right
-- of the index. The array size should be big enough for this
-- shift, this is not checked.
unsafeInsertArray :: (PrimMonad m)
  => Int -- ^ Size of the original array
  -> Int -- ^ Index
  -> a -- ^ Value
  -> MutableArray (PrimState m) a -- ^ Array to modify
  -> m ()
unsafeInsertArray sz i x marr = do
    P.copyMutableArray marr (i + 1) marr i (sz - i)
    P.writeArray marr i x

-- | Inserts a value at the designated index, shifting everything after it
-- one slot to the right. The array must have room for @sz + 1@ elements;
-- this is not checked.
--
-- Example:
--
-- > -----------------------------
-- > | a | b | c | d | e | X | X |
-- > -----------------------------
-- > unsafeInsertPrimArray 5 3 'k' marr
-- > -----------------------------
-- > | a | b | c | k | d | e | X |
-- > -----------------------------
unsafeInsertPrimArray :: (PrimMonad m, Prim a)
  => Int -- ^ Size of the original array
  -> Int -- ^ Index
  -> a -- ^ Value
  -> MutablePrimArray (PrimState m) a -- ^ Array to modify
  -> m ()
unsafeInsertPrimArray sz i x marr = do
    copyMutablePrimArray marr (i + 1) marr i (sz - i)
    writePrimArray marr i x

-- | Render the first @sz@ index-aligned key/value pairs as @"key: value"@
-- strings, in index order.
showPairs :: forall m k v. (PrimMonad m, Show k, Show v, Prim k, Prim v)
  => Int -- size
  -> MutablePrimArray (PrimState m) k
  -> MutablePrimArray (PrimState m) v
  -> m [String]
showPairs sz keys values = go 0
  where
    go :: Int -> m [String]
    go ix =
      if ix < sz
        then do
          k <- readPrimArray keys ix
          v <- readPrimArray values ix
          let str = show k ++ ": " ++ show v
          strs <- go (ix + 1)
          return (str : strs)
        else return []

-- | Show the internal structure of a tree, one entry per line, indented by
-- two spaces per level of depth. Intended for debugging only; the output
-- format is not stable.
debugMap :: forall m k v. (PrimMonad m, Prim k, Prim v, Show k, Show v)
  => Context (PrimState m)
  -> BTree (PrimState m) k v
  -> m String
debugMap (Context _) (BTree rootSzRef rootKeys rootContents) = do
    rootSz <- readMutVar rootSzRef
    let go :: Int -> Int -> MutablePrimArray (PrimState m) k -> Contents (PrimState m) k v -> m [(Int,String)]
        go level sz keys c = case c of
          ContentsValues values -> do
            pairStrs <- showPairs sz keys values
            return (map (\s -> (level,s)) pairStrs)
          ContentsNodes nodes -> do
            pairs <- pairForM sz keys nodes $ \k (BTree nextSzRef nextKeys nextContents) -> do
              nextSz <- readMutVar nextSzRef
              nextStrs <- go (level + 1) nextSz nextKeys nextContents
              return (nextStrs ++ [(level,show k)])
            -- An internal node with sz keys has sz + 1 children. pairForM
            -- only visited the first sz of them, so render the last one here.
            BTree lastSzRef lastKeys lastContents <- P.readArray nodes sz
            lastSz <- readMutVar lastSzRef
            lastStrs <- go (level + 1) lastSz lastKeys lastContents
            return ([(level, "start")] ++ concat pairs ++ lastStrs)
    allStrs <- go 0 rootSz rootKeys rootContents
    return $ unlines $ map (\(level,str) -> replicate (level * 2) ' ' ++ str) ((0,"root size: " ++ show rootSz) : allStrs)

-- | Run an action on the elements at indices @0 .. sz - 1@ of a primitive
-- array and a boxed array, pairwise, collecting the results in index order.
pairForM :: forall m a b c. (PrimMonad m, Prim a)
  => Int
  -> MutablePrimArray (PrimState m) a
  -> MutableArray (PrimState m) c
  -> (a -> c -> m b)
  -> m [b]
pairForM sz marr1 marr2 f = go 0
  where
    go :: Int -> m [b]
    go ix =
      if ix < sz
        then do
          a <- readPrimArray marr1 ix
          c <- P.readArray marr2 ix
          b <- f a c
          bs <- go (ix + 1)
          return (b : bs)
        else return []