{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE KindSignatures        #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables   #-}
-- | Internals of 'TDigest'.
--
-- Tree implementation is based on /Adams’ Trees Revisited/ by Milan Straka
--
module Data.TDigest.Tree.Internal where

import Control.DeepSeq        (NFData (..))
import Control.Monad.ST       (ST, runST)
import Data.Binary            (Binary (..))
import Data.Either            (isRight)
import Data.Foldable          (toList)
import Data.List.Compat       (foldl')
import Data.List.NonEmpty     (nonEmpty)
import Data.Ord               (comparing)
import Data.Proxy             (Proxy (..))
import Data.Semigroup         (Semigroup (..))
import Data.Semigroup.Reducer (Reducer (..))
import GHC.TypeLits           (KnownNat, Nat, natVal)
import Prelude ()
import Prelude.Compat

import qualified Data.Vector.Algorithms.Heap as VHeap
import qualified Data.Vector.Unboxed         as VU
import qualified Data.Vector.Unboxed.Mutable as MVU

import           Data.TDigest.Internal
import qualified Data.TDigest.Postprocess.Internal as PP

-------------------------------------------------------------------------------
-- TDigest
-------------------------------------------------------------------------------

-- | 'TDigest' is a tree of centroids.
--
-- @compression@ is a @1/δ@. The greater the value of @compression@ the less
-- likely value merging will happen.
--
-- The @compression@ parameter lives at the type level only (kind 'Nat'); it
-- is read back with 'natVal' by the functions that need its value.
data TDigest (compression :: Nat)
    -- | Tree node
    = Node
        {-# UNPACK #-} !Size     -- size of this tree: number of nodes in the subtree, this node included
        {-# UNPACK #-} !Mean     -- mean of the centroid
        {-# UNPACK #-} !Weight   -- weight of the centroid
        {-# UNPACK #-} !Weight   -- total weight of the tree: this centroid plus both subtrees
        !(TDigest compression)   -- left subtree
        !(TDigest compression)   -- right subtree
    -- | Empty tree
    | Nil
  deriving (Show)

-- [Note: keep min & max in the tree]
--
-- We tried it, but it seems the alloc/update cost is bigger than
-- re-calculating them on need (it's O(log n) - calculation!)
-- [Note: singleton node]
-- We tried to add one, but haven't seen change in performance

-- [Note: inlining balanceR and balanceL]
-- We probably can squeeze some performance by making
-- 'balanceL' and 'balanceR' check arguments only once (like @containers@ do)
-- and not use 'node' function.
-- *But*, the benefit vs. code explosion is not yet worth.

-- | Combining two digests is 'combineDigest'.
instance KnownNat comp => Semigroup (TDigest comp) where
    a <> b = combineDigest a b

-- | Both 'cons' and 'snoc' are 'insert'
instance KnownNat comp => Reducer Double (TDigest comp) where
    cons x td = insert x td
    snoc td x = insert x td
    unit x    = singleton x

-- | The identity element is 'emptyTDigest'.
instance KnownNat comp => Monoid (TDigest comp) where
    mempty      = emptyTDigest
    mappend a b = combineDigest a b

-- | 'TDigest' has only strict fields.
instance NFData (TDigest comp) where
    rnf td = td `seq` ()

-- | 'TDigest' isn't compressed after de-serialisation,
-- but it can be still smaller.
instance KnownNat comp => Binary (TDigest comp) where
    -- Only the in-order list of centroids is written.
    put td = put (getCentroids td)

    -- The tree is rebuilt by inserting the stored centroids one by one.
    get = do
        centroids <- get
        pure (foldl' (flip insertCentroid) emptyTDigest (centroids :: [Centroid]))

instance PP.HasHistogram (TDigest comp) Maybe where
    -- 'Nothing' when the digest holds no centroids.
    histogram td = fmap PP.histogramFromCentroids (nonEmpty (getCentroids td))
    totalWeight  = totalWeight

-- | All centroids of the digest, produced by an in-order traversal
-- (left subtree, node, right subtree).
getCentroids :: TDigest comp -> [Centroid]
getCentroids td = collect td []
  where
    -- The second argument is the list of centroids following this subtree.
    collect Nil                rest = rest
    collect (Node _ x w _ l r) rest = collect l ((x, w) : collect r rest)

-- | Total count of samples.
--
-- >>> totalWeight (tdigest [1..100] :: TDigest 5)
-- 100.0
--
totalWeight :: TDigest comp -> Weight
totalWeight td = case td of
    Nil               -> 0
    Node _ _ _ tw _ _ -> tw

-- | Number of nodes (centroids) stored in the tree.
size :: TDigest comp -> Int
size td = case td of
    Nil              -> 0
    Node s _ _ _ _ _ -> s

-- | Center of left-most centroid. Note: may be different than min element inserted.
--
-- >>> minimumValue (tdigest [1..100] :: TDigest 3)
-- 1.0
--
minimumValue :: TDigest comp -> Mean
minimumValue Nil                = posInf
minimumValue (Node _ x _ _ l _) = descend x l
  where
    -- Keep walking left, remembering the mean of the last node seen.
    descend best Nil                 = best
    descend _    (Node _ y _ _ l' _) = descend y l'

-- | Center of right-most centroid. Note: may be different than max element inserted.
--
-- >>> maximumValue (tdigest [1..100] :: TDigest 3)
-- 99.0
--
maximumValue :: TDigest comp -> Mean
maximumValue = go negInf
  where
    -- Follow right children only; the accumulator is the mean of the last
    -- node visited ('negInf' when the tree is empty).
    go acc  Nil                = acc
    go _acc (Node _ x _ _ _ r) = go x r

-------------------------------------------------------------------------------
-- Implementation
-------------------------------------------------------------------------------

-- | The digest without any centroids.
emptyTDigest :: TDigest comp
emptyTDigest = Nil

-- | Merge two digests.
--
-- The centroids of the tree with fewer nodes are inserted, one by one, into
-- the other tree; the result is then passed through 'compress'.
-- If either argument is empty the other one is returned unchanged.
combineDigest
    :: KnownNat comp
    => TDigest comp
    -> TDigest comp
    -> TDigest comp
combineDigest a Nil = a
combineDigest Nil b = b
combineDigest a@(Node n _ _ _ _ _) b@(Node m _ _ _ _ _)
    -- TODO: merge first, then shuffle and insert (part of compress)
    | n < m     = compress $ foldl' (flip insertCentroid) b (getCentroids a)
    | otherwise = compress $ foldl' (flip insertCentroid) a (getCentroids b)

-- | Insert a centroid (mean and weight) into the digest.
--
-- * Inserting into an empty tree creates a single node.
-- * If a node with exactly the same mean exists on the search path, the
--   weight is added to that node.
-- * Otherwise, if the node on the path still has room below its 'threshold'
--   and is closer to the new mean than its neighbouring subtree extreme, as
--   much weight as fits is merged into it; the node's mean becomes the
--   weighted mean of the old mean and the inserted mean. Any weight that
--   did not fit is inserted further down as is.
-- * If nothing can be merged, a new node is created at a leaf.
--
-- The tree is rebalanced with 'balanceL' / 'balanceR' on the way back up.
insertCentroid
    :: forall comp. KnownNat comp
    => Centroid
    -> TDigest comp
    -> TDigest comp
insertCentroid (x, w) Nil        = singNode x w
insertCentroid (mean, weight) td = go 0 mean weight False td
  where
    -- New weight of the tree
    n :: Weight
    n = totalWeight td + weight

    -- 1/delta
    compression :: Double
    compression = fromInteger $ natVal (Proxy :: Proxy comp)

    go
        :: Weight        -- weight to the left of this tree
        -> Mean          -- mean to insert
        -> Weight        -- weight to insert
        -> Bool          -- should insert everything.
                         -- if we merged somewhere on top, rest is inserted as is
        -> TDigest comp  -- subtree to insert/merge centroid into
        -> TDigest comp
    go _   newX newW _ Nil                 = singNode newX newW
    go cum newX newW e (Node s x w tw l r) = case compare newX x of
        -- Exact match, insert here
        EQ -> Node s x (w + newW) (tw + newW) l r -- node x (w + newW) l r

        -- there is *no* room to insert into this node
        LT | thr <= w -> balanceL x w (go cum newX newW e l) r
        GT | thr <= w -> balanceR x w l (go (cum + totalWeight l + w) newX newW e r)

        -- otherwise go left ... or later right
        LT | e -> balanceL x w (go cum newX newW e l) r
        LT -> case l of
            -- always create a new node
            Nil -> case mrw of
                Nothing -> node' s nx nw (tw + newW) Nil r
                Just rw -> balanceL nx nw (go cum newX rw True Nil) r
            Node {}
                | lmax < newX && abs (newX - x) < abs (newX - lmax) {- && newX < x -} -> case mrw of
                    -- everything fits: shape is unchanged, the subtree only
                    -- gains the inserted weight (same as the sibling branches)
                    Nothing -> node' s nx nw (tw + newW) l r
                    -- in this two last LT cases, we have to recalculate size
                    Just rw -> balanceL nx nw (go cum newX rw True l) r
                | otherwise -> balanceL x w (go cum newX newW e l) r
              where
                lmax = maximumValue l

        -- ... or right
        GT | e -> balanceR x w l (go (cum + totalWeight l + w) newX newW True r)
        GT -> case r of
            Nil -> case mrw of
                Nothing -> node' s nx nw (tw + newW) l Nil
                Just rw -> balanceR nx nw l (go (cum + totalWeight l + nw) newX rw True Nil)
            Node {}
                | rmin > newX && abs (newX - x) < abs (newX - rmin) {- && newX > x -} -> case mrw of
                    Nothing -> node' s nx nw (tw + newW) l r
                    -- in this two last GT cases, we have to recalculate size
                    Just rw -> balanceR nx nw l (go (cum + totalWeight l + nw) newX rw True r)
                | otherwise -> balanceR x w l (go (cum + totalWeight l + w) newX newW e r)
              where
                rmin = minimumValue r
      where
        -- quantile approximation of current node
        cum' = cum + totalWeight l
        q    = (w / 2 + cum') / n

        -- threshold, max size of current node/centroid
        thr = {- traceShowId $ traceShow (n, q) $ -} threshold n q compression

        -- We later use nx, nw and mrw:

        -- dw:  weight merged into the current node
        -- mrw: weight that did not fit and still has to be inserted, if any
        dw :: Weight
        mrw :: Maybe Weight
        (dw, mrw) =
            let diff = assert (thr > w) "threshold should be larger than current node weight"
                     $ w + newW - thr
            in if diff < 0 -- i.e. there is room
                then (newW, Nothing)
                else (thr - w, Just diff)

        -- the change of current node: weighted mean of the node's own
        -- centroid (x, w) and the merged part of the inserted one (newX, dw)
        (nx, nw) = {- traceShowId $ traceShow (newX, newW, x, dw, mrw) $ -} combinedCentroid x w newX dw

-- | Constructor which calculates size and total weight.
node :: Mean -> Weight -> TDigest comp -> TDigest comp -> TDigest comp
node x w l r = Node nodeCount x w weightSum l r
  where
    -- this node plus everything below it
    nodeCount = 1 + size l + size r
    weightSum = w + totalWeight l + totalWeight r

-- | Balance after right insertion.
--
-- When the right subtree has become more than 'balOmega' times larger than
-- the left one, a single or double left rotation is applied; otherwise the
-- node is simply rebuilt with 'node'.
balanceR :: Mean -> Weight -> TDigest comp -> TDigest comp -> TDigest comp
balanceR x w l r
    | sl + sr <= 1       = node x w l r
    | sr > balOmega * sl = rotated
    | otherwise          = node x w l r
  where
    sl = size l
    sr = size r

    rotated = case r of
        Nil -> error "balanceR: impossible happened"
        Node _ rx rw _ rl rr -> case rl of
            -- double left rotation: the inner grandchild is too heavy
            Node _ rlx rlw _ rll rlr
                | size rl >= balAlpha * size rr ->
                    node rlx rlw (node x w l rll) (node rx rw rlr rr)
            -- single left rotation (also covers an empty inner grandchild)
            _ -> node rx rw (node x w l rl) rr

-- | Balance after left insertion.
--
-- Mirror image of 'balanceR': when the left subtree is more than 'balOmega'
-- times larger than the right one, a single or double right rotation is
-- applied.
balanceL :: Mean -> Weight -> TDigest comp -> TDigest comp -> TDigest comp
balanceL x w l r
    | sl + sr <= 1       = node x w l r
    | sl > balOmega * sr = rotated
    | otherwise          = node x w l r
  where
    sl = size l
    sr = size r

    rotated = case l of
        Nil -> error "balanceL: impossible happened"
        Node _ lx lw _ ll lr -> case lr of
            -- double right rotation: the inner grandchild is too heavy
            Node _ lrx lrw _ lrl lrr
                | size lr >= balAlpha * size ll ->
                    node lrx lrw (node lx lw ll lrl) (node x w lrr r)
            -- single right rotation (also covers an empty inner grandchild)
            _ -> node lx lw ll (node x w lr r)

-- | Alias to 'Node'
node' :: Int -> Mean -> Weight -> Weight -> TDigest comp -> TDigest comp -> TDigest comp
node' s x w tw l r = Node s x w tw l r

-- | Create singular node.
singNode :: Mean -> Weight -> TDigest comp
singNode x w = node' 1 x w w Nil Nil

-- | Add two weighted means together.
--
-- The result is the weighted mean of both inputs paired with the sum of
-- their weights.
combinedCentroid
    :: Mean -> Weight
    -> Mean -> Weight
    -> Centroid
combinedCentroid m1 w1 m2 w2 = (weighted / total, total)
  where
    total    = w1 + w2
    weighted = m1 * w1 + m2 * w2 -- this is probably not num. stable

-- | Calculate the threshold, i.e. maximum weight of centroid.
threshold
    :: Double  -- ^ total weight
    -> Double  -- ^ quantile
    -> Double  -- ^ compression (1/δ)
    -> Double
threshold total quantile oneOverDelta =
    4 * total * quantile * (1 - quantile) / oneOverDelta

-------------------------------------------------------------------------------
-- Compression
-------------------------------------------------------------------------------

-- | Compress 'TDigest'.
--
-- Reinsert the centroids in "better" order (in original paper: in random)
-- so they have opportunity to merge.
--
-- Compression will happen only if size is both:
-- bigger than @'relMaxSize' * comp@ and bigger than 'absMaxSize'.
--
compress :: forall comp. KnownNat comp => TDigest comp -> TDigest comp
compress Nil = Nil
compress td
    | tooLarge  = forceCompress td
    | otherwise = td
  where
    current     = size td
    relLimit    = relMaxSize * fromInteger (natVal (Proxy :: Proxy comp))
    tooLarge    = current > relLimit && current > absMaxSize

-- | Perform compression, even if current size says it's not necessary.
--
-- All centroids are re-inserted into an empty digest, ordered by the
-- "space left" value computed by 'toMVector'.
forceCompress :: forall comp. KnownNat comp => TDigest comp -> TDigest comp
forceCompress Nil = Nil
forceCompress td =
    foldl' (flip insertCentroid) emptyTDigest (map fst (VU.toList ordered))
  where
    -- Centroids are shuffled based on space
    ordered :: VU.Vector (Centroid, Double)
    ordered = runST $ do
        buf <- toMVector td
        -- sort by the space value (second component)
        VHeap.sortBy (comparing snd) buf
        VU.unsafeFreeze buf

-- | Copy the centroids, in tree order, into a fresh mutable vector.
toMVector
    :: forall comp s. KnownNat comp
    => TDigest comp                            -- ^ t-Digest
    -> ST s (VU.MVector s (Centroid, Double))  -- ^ return also a "space left in the centroid" value for "shuffling"
toMVector td = do
    buf <- MVU.new nodeCount
    (written, accumulated) <- fill buf (0 :: Int) (0 :: Double) td
    let intact = written == nodeCount && abs (accumulated - totalW) < 1e-6
    pure (assert intact "traversal in toMVector:" buf)
  where
    nodeCount = size td
    totalW    = totalWeight td
    oneOverDelta = fromInteger $ natVal (Proxy :: Proxy comp)

    -- In-order traversal threading the next free index and the weight
    -- accumulated so far; returns both after the subtree is written.
    fill _   ix acc Nil                = pure (ix, acc)
    fill out ix acc (Node _ x w _ l r) = do
        (ix', acc') <- fill out ix acc l
        MVU.unsafeWrite out ix' ((x, w), spaceLeft w acc')
        fill out (ix' + 1) (acc' + w) r

    -- threshold of the centroid at its quantile minus its current weight
    spaceLeft w before = limit - w
      where
        quantile = (w / 2 + before) / totalW
        limit    = threshold totalW quantile oneOverDelta

-------------------------------------------------------------------------------
-- Params
-------------------------------------------------------------------------------

-- | Relative size parameter. Hard-coded value: 25.
relMaxSize :: Int
relMaxSize = 25

-- | Absolute size parameter. Hard-coded value: 1000.
absMaxSize :: Int
absMaxSize = 1000

-------------------------------------------------------------------------------
-- Tree balance parameters
-------------------------------------------------------------------------------

-- | Maximum allowed size ratio between sibling subtrees.
balOmega :: Int
balOmega = 3

-- | Ratio deciding between a single and a double rotation.
balAlpha :: Int
balAlpha = 2

-- balDelta = 0

-------------------------------------------------------------------------------
-- Debug
-------------------------------------------------------------------------------

-- | Output the 'TDigest' tree.
--
-- Every node is printed on its own line, indented by three spaces per level,
-- with the left subtree before and the right subtree after it.
debugPrint :: TDigest comp -> IO ()
debugPrint = walk 0
  where
    pad depth = replicate (depth * 3) ' '

    walk depth Nil = putStrLn (pad depth ++ "Nil")
    walk depth (Node s m w tw l r) = do
        walk (depth + 1) l
        putStrLn (pad depth ++ "Node " ++ show (s,m,w,tw))
        walk (depth + 1) r

-- | @'isRight' . 'validate'@
valid :: TDigest comp -> Bool
valid td = isRight (validate td)

-- | Check various invariants in the 'TDigest' tree.
validate :: TDigest comp -> Either String (TDigest comp)
validate td = case failures of
    (msg : _) -> Left msg
    []        -> Right td
  where
    -- Messages of the violated invariants, in checking order; the list is
    -- lazy, so checking stops at the first failure.
    failures = [ msg | (holds, msg) <- checks, not (all holds subtrees) ]

    checks =
        [ (sizeValid,   "invalid sizes")
        , (weightValid, "invalid weights")
        , (orderValid,  "invalid ordering")
        , (balanced,    "tree is ill-balanced")
        ]

    -- every non-empty subtree, in pre-order
    subtrees = collect td []

    collect Nil                     rest = rest
    collect t@(Node _ _ _ _ l r)    rest = t : collect l (collect r rest)

    -- stored size equals the recomputed one
    sizeValid Nil                = True
    sizeValid (Node s _ _ _ l r) = s == size l + size r + 1

    -- stored total weight agrees (via 'eq') with the recomputed one
    weightValid Nil                 = True
    weightValid (Node _ _ w tw l r) = eq tw $ w + totalWeight l + totalWeight r

    -- the mean lies strictly between the means of the direct children
    orderValid Nil                = True
    orderValid (Node _ x _ _ l r) = leftOk l && rightOk r
      where
        leftOk Nil                  = True
        leftOk (Node _ lx _ _ _ _)  = lx < x

        rightOk Nil                 = True
        rightOk (Node _ rx _ _ _ _) = x < rx

    -- neither child is more than 'balOmega' times larger than the other
    balanced Nil                = True
    balanced (Node _ _ _ _ l r) =
        sl <= max 1 (balOmega * sr) &&
        sr <= max 1 (balOmega * sl)
      where
        sl = size l
        sr = size r

-------------------------------------------------------------------------------
-- Higher level helpers
-------------------------------------------------------------------------------

-- | Insert single value into 'TDigest'.
insert
    :: KnownNat comp
    => Double         -- ^ element
    -> TDigest comp
    -> TDigest comp
insert x td = compress (insert' x td)

-- | Insert single value, don't compress 'TDigest' even if needed.
--
-- For sensibly bounded input, it makes sense to let 'TDigest' grow (it might
-- grow linearly in size), and after that compress it once.
insert'
    :: KnownNat comp
    => Double         -- ^ element
    -> TDigest comp
    -> TDigest comp
insert' x td = insertCentroid (x, 1) td

-- | Make a 'TDigest' of a single data point.
singleton :: KnownNat comp => Double -> TDigest comp
singleton x = insert x emptyTDigest

-- | Strict 'foldl'' over 'Foldable' structure.
--
-- The input is processed in chunks; 'compress' is applied after each chunk.
tdigest :: (Foldable f, KnownNat comp) => f Double -> TDigest comp
tdigest xs = foldl' insertChunk emptyTDigest (chunks (toList xs))
  where
    -- compress after each chunk, forceCompress at the very end.
    insertChunk td chunk = compress (foldl' (flip insert') td chunk)

    chunks [] = []
    chunks ys = chunk : chunks rest
      where
        (chunk, rest) = splitAt 1000 ys -- 1000 is totally arbitrary.

-- $setup
-- >>> :set -XDataKinds