module STMContainers.HAMT.Nodes where import STMContainers.Prelude hiding (insert, lookup, delete, foldM, null) import qualified STMContainers.Prelude as Prelude import qualified STMContainers.WordArray as WordArray import qualified STMContainers.SizedArray as SizedArray import qualified STMContainers.HAMT.Level as Level import qualified Focus import qualified ListT type Nodes e = TVar (WordArray.WordArray (Node e)) data Node e = Nodes {-# UNPACK #-} !(Nodes e) | Leaf {-# UNPACK #-} !Hash !e | Leaves {-# UNPACK #-} !Hash {-# UNPACK #-} !(SizedArray.SizedArray e) type Hash = Int class (Eq (ElementKey e)) => Element e where type ElementKey e elementKey :: e -> ElementKey e {-# INLINE new #-} new :: STM (Nodes e) new = newTVar WordArray.empty {-# INLINE newIO #-} newIO :: IO (Nodes e) newIO = newTVarIO WordArray.empty insert :: (Element e) => e -> Hash -> ElementKey e -> Level.Level -> Nodes e -> STM () insert e h k l ns = do a <- readTVar ns let write n = writeTVar ns $ WordArray.set i n a case WordArray.lookup i a of Nothing -> write (Leaf h e) Just n -> case n of Nodes ns' -> insert e h k (Level.succ l) ns' Leaf h' e' -> if h' == h then if elementKey e' == k then write (Leaf h e) else write (Leaves h (SizedArray.pair e e')) else do nodes <- pair h (Leaf h e) h' (Leaf h' e') (Level.succ l) write (Nodes nodes) Leaves h' la -> if h' == h then case SizedArray.find ((== k) . elementKey) la of Just (lai, _) -> write (Leaves h' (SizedArray.insert lai e la)) Nothing -> write (Leaves h' (SizedArray.append e la)) else write . Nodes =<< pair h (Leaf h e) h' (Leaves h' la) (Level.succ l) where i = Level.hashIndex l h pair :: Hash -> Node e -> Hash -> Node e -> Level.Level -> STM (Nodes e) pair h1 n1 h2 n2 l = if i1 == i2 then newTVar . WordArray.singleton i1 . Nodes =<< pair h1 n1 h2 n2 (Level.succ l) else newTVar $ WordArray.pair i1 n1 i2 n2 where hashIndex = Level.hashIndex l i1 = hashIndex h1 i2 = hashIndex h2 focus :: (Element e) => Focus.StrategyM STM e r -> Hash -> ElementKey e -> Level.Level -> Nodes e -> STM r focus s h k l ns = do a <- readTVar ns (r, a'm) <- WordArray.focusM s' ai a maybe (return ()) (writeTVar ns) a'm return r where ai = Level.hashIndex l h s' = \case Nothing -> traversePair (return . fmap (Leaf h)) =<< s Nothing Just n -> case n of Nodes ns' -> do r <- focus s h k (Level.succ l) ns' null ns' >>= \case True -> return (r, Focus.Remove) False -> return (r, Focus.Keep) Leaf h' e' -> case h' == h of True -> case elementKey e' == k of True -> traversePair (return . fmap (Leaf h)) =<< s (Just e') False -> traversePair processDecision =<< s Nothing where processDecision = \case Focus.Replace e -> return (Focus.Replace (Leaves h (SizedArray.pair e e'))) _ -> return Focus.Keep False -> traversePair processDecision =<< s Nothing where processDecision = \case Focus.Replace e -> do ns' <- pair h (Leaf h e) h' (Leaf h' e') (Level.succ l) return (Focus.Replace (Nodes ns')) _ -> return Focus.Keep Leaves h' a' -> case h' == h of True -> case SizedArray.find ((== k) . elementKey) a' of Just (i', e') -> s (Just e') >>= traversePair processDecision where processDecision = \case Focus.Keep -> return Focus.Keep Focus.Remove -> case SizedArray.delete i' a' of a'' -> case SizedArray.null a'' of False -> return (Focus.Replace (Leaves h' a'')) True -> return Focus.Remove Focus.Replace e -> return (Focus.Replace (Leaves h' (SizedArray.insert i' e a'))) Nothing -> s Nothing >>= traversePair processDecision where processDecision = \case Focus.Replace e -> return (Focus.Replace (Leaves h' (SizedArray.append e a'))) _ -> return Focus.Keep False -> s Nothing >>= traversePair processDecision where processDecision = \case Focus.Replace e -> do ns' <- pair h (Leaf h e) h' (Leaves h' a') (Level.succ l) return (Focus.Replace (Nodes ns')) _ -> return Focus.Keep null :: Nodes e -> STM Bool null = fmap WordArray.null . readTVar foldM :: (a -> e -> STM a) -> a -> Level.Level -> Nodes e -> STM a foldM step acc level = readTVar >=> foldlM step' acc where step' acc' = \case Nodes ns -> foldM step acc' (Level.succ level) ns Leaf _ e -> step acc' e Leaves _ a -> SizedArray.foldM step acc' a stream :: Level.Level -> Nodes e -> ListT.ListT STM e stream l = lift . readTVar >=> ListT.fromFoldable >=> \case Nodes n -> stream (Level.succ l) n Leaf _ e -> return e Leaves _ a -> ListT.fromFoldable a size :: Nodes e -> STM Int size nodes = readTVar nodes >>= foldlM step 0 where step a = fmap (a+) . nodeSize where nodeSize :: Node e -> STM Int nodeSize = \case Nodes nodes -> size nodes Leaf _ _ -> pure 1 Leaves _ x -> pure (SizedArray.size x)