{-# language ExistentialQuantification, TypeFamilies, FlexibleInstances, MultiParamTypeClasses #-}
module Data.Sparse.Internal.Utils where

import Control.Monad (unless)
import Control.Monad.State
import Control.Monad.ST
import Data.Ord (comparing)
import qualified Data.Vector as V
-- import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Mutable as VM
import qualified Data.Vector.Algorithms.Merge as VA
-- import Data.Ix
-- import Data.Maybe
-- import Data.Sparse.Types
-- import Numeric.LinearAlgebra.Class


-- | Build the "row (resp. column) pointer" vector of the CSR (resp. CSC)
-- format.
--
-- Given a number of rows (resp. columns) @n@ and a vector @xs@ of nonzero
-- entries sorted by increasing row (resp. column) index, return a vector of
-- length @n + 1@. Element 0 is always 0, and element @i + 1@ is the
-- cumulative number of entries belonging to indices @0 .. i@. The predicate
-- @eqf x i@ decides whether entry @x@ belongs to index @i@.
--
-- Counting and accumulation are fused into a single left-to-right pass.
-- Only a prefix of @xs@ is ever consumed. If @xs@ is not sorted, or holds an
-- entry that matches no index in @[0 .. n - 1]@, that entry and everything
-- after it are left uncounted.
--
-- E.g.:
--
-- > csrPtrV (==) 4 (V.fromList [1,1,2,3])
-- [0,0,2,3,4]
csrPtrV :: (a -> Int -> Bool) -> Int -> V.Vector a -> V.Vector Int
csrPtrV eqf n xs = V.create createf
  where
    createf :: ST s (VM.MVector s Int)
    createf = do
      vm <- VM.new (n + 1)
      VM.write vm 0 0
      let loop ys i count
            | i >= n = return ()
            | otherwise = do
                -- number of leading entries of the remainder that belong to index i
                let lp = V.length (V.takeWhile (`eqf` i) ys)
                    count' = count + lp
                -- force the running total so the stored counts do not form
                -- a chain of nested (+) thunks
                VM.write vm (i + 1) $! count'
                loop (V.drop lp ys) (i + 1) count'
      loop xs 0 0
      return vm

-- NOTE(review): unfinished draft of an alternative to 'csrPtrV' built on
-- 'numLoop'. It does not compile as written (@ll@ is unbound and the inner
-- @let@ has no body). Kept for reference only.
-- csrPtrV' eqf n xs = V.create createf where
--   createf :: ST s (VM.MVector s Int)
--   createf = do
--     let c = 0
--     vm <- VM.new (n + 1)
--     VM.write vm 0 0
--     numLoop fw 0 where
--       fw ix = let lp = V.length $ V.takeWhile (`eqf` ix) ll
--                   count
--     return vm


-- | Run @fm@ on every index @0, 1, .., n - 1@, in order, discarding the
-- results. Does nothing when @n <= 0@.
numLoop :: Monad m => (Int -> m a) -> Int -> m ()
numLoop fm n = go 0
  where
    go i
      | i >= n = return ()
      | otherwise = fm i >> go (i + 1)

-- | Like 'numLoop', but also threads an explicit state.
--
-- Step @i@ runs @fm i s_i@, where @s_0 = s0@ and @s_(i+1) = fs s_i@. The
-- results of @fm@ are discarded. Does nothing when @n <= 0@.
numLoopST' :: Monad m => (Int -> s -> m a) -> Int -> (s -> s) -> s -> m ()
numLoopST' fm n fs s0 = go 0 s0
  where
    go i s
      | i >= n = return ()
      | otherwise = fm i s >> go (i + 1) (fs s)

-- | Like 'numLoopST'', but keeps the state inside a 'MonadState' context.
--
-- At step @i@ the current state @s@ is read, @fm i s@ is run (its result is
-- discarded), and the state is then set to @fs s@. Note that @fs@ is applied
-- to the state read /before/ @fm@ ran, so any state change made by @fm@
-- itself is overwritten. Does nothing when @n <= 0@.
numLoopST'' :: MonadState s m => (Int -> s -> m a) -> Int -> (s -> s) -> m ()
numLoopST'' fm n fs = go 0
  where
    go i = unless (i >= n) $ do
      s <- get
      _ <- fm i s -- ignore result of `fm`
      put (fs s)
      go (i + 1)


-- | O(N): intersection of two vectors that are both sorted by the key @f@.
--
-- Elements with equal keys are matched up in order and combined with @g@.
-- Elements whose key has no partner in the other vector are dropped.
intersectWith :: Ord o => (a -> o) -> (a -> a -> c) -> V.Vector a -> V.Vector a -> V.Vector c
intersectWith f = intersectWithCompare (comparing f)

-- | O(N): intersection of two vectors sorted consistently with @fcomp@.
--
-- The inputs are not modified. The result is written into a freshly
-- allocated buffer of size @min (length u_) (length v_)@, which is then
-- truncated to the number of matches.
intersectWithCompare :: (a1 -> a2 -> Ordering) -> (a1 -> a2 -> a) -> V.Vector a1 -> V.Vector a2 -> V.Vector a
intersectWithCompare fcomp g u_ v_ = V.create $ do
  vm <- VM.new (min nu nv)
  -- iu / iv: read positions in u_ / v_; k: next free slot in vm.
  -- Every write consumes one element of each input, so k never exceeds
  -- min nu nv.
  let go iu iv k
        | iu >= nu || iv >= nv = return k
        | otherwise =
            let u = u_ V.! iu
                v = v_ V.! iv
            in case fcomp u v of
                 EQ -> do
                   VM.write vm k (g u v)
                   go (iu + 1) (iv + 1) (k + 1)
                 LT -> go (iu + 1) iv k
                 GT -> go iu (iv + 1) k
  count <- go 0 0 0
  -- keep only the slots that were actually written
  return (VM.take count vm)
  where
    nu = V.length u_
    nv = V.length v_

-- | O(N): union of two vectors that are both sorted by the key @f@.
--
-- Elements with equal keys are combined with @g u v@. An element present in
-- only one input is paired with the filler @z@: @g u z@ for the left input,
-- @g z v@ for the right.
unionWith :: Ord o => (t -> o) -> (t -> t -> a) -> t -> V.Vector t -> V.Vector t -> V.Vector a
unionWith f = unionWithCompare (comparing f)

-- | O(N): union of two vectors sorted consistently with @fcomp@.
--
-- The inputs are not modified. The result is written into a freshly
-- allocated buffer of size @length u_ + length v_@, which is then truncated
-- to the number of elements produced.
unionWithCompare :: (t -> t -> Ordering) -> (t -> t -> a) -> t -> V.Vector t -> V.Vector t -> V.Vector a
unionWithCompare fcomp g z u_ v_ = V.create $ do
  vm <- VM.new (nu + nv)
  -- iu / iv: read positions in u_ / v_; k: next free slot in vm.
  -- Every write consumes at least one input element, so k never exceeds
  -- nu + nv.
  let go iu iv k
        | iu >= nu && iv >= nv = return k
        | iu >= nu = do
            -- left input exhausted: pair the rest of the right one with z
            VM.write vm k (g z (v_ V.! iv))
            go iu (iv + 1) (k + 1)
        | iv >= nv = do
            -- right input exhausted: pair the rest of the left one with z
            VM.write vm k (g (u_ V.! iu) z)
            go (iu + 1) iv (k + 1)
        | otherwise = do
            let u = u_ V.! iu
                v = v_ V.! iv
            case fcomp u v of
              EQ -> do
                VM.write vm k (g u v)
                go (iu + 1) (iv + 1) (k + 1)
              LT -> do
                VM.write vm k (g u z)
                go (iu + 1) iv (k + 1)
              GT -> do
                VM.write vm k (g z v)
                go iu (iv + 1) (k + 1)
  count <- go 0 0 0
  -- keep only the slots that were actually written
  return (VM.take count vm)
  where
    nu = V.length u_
    nv = V.length v_


-- * Sorting

-- | Sort a vector by the key @f@, using the merge sort from
-- "Data.Vector.Algorithms.Merge". The input vector is not observably
-- modified ('V.modify' works on a copy when needed).
sortWith :: Ord o => (t -> o) -> V.Vector t -> V.Vector t
sortWith f = V.modify (VA.sortBy (comparing f))

-- | Zip three vectors into triples (truncating to the shortest one) and sort
-- the triples by the key @f@.
sortWith3 :: Ord o => ((a, b, c) -> o) -> V.Vector a -> V.Vector b -> V.Vector c -> V.Vector (a, b, c)
sortWith3 f x y z = sortWith f $ V.zip3 x y z

-- | Zip three vectors into triples and sort them by their first component.
sortByFst3 :: Ord a => V.Vector a -> V.Vector b -> V.Vector c -> V.Vector (a, b, c)
sortByFst3 = sortWith3 fst3

-- | Zip three vectors into triples and sort them by their second component.
sortBySnd3 :: Ord b => V.Vector a -> V.Vector b -> V.Vector c -> V.Vector (a, b, c)
sortBySnd3 = sortWith3 snd3


-- * Utilities

-- ** 3-tuples

fst3 :: (a, b, c) -> a
fst3 (a, _, _) = a

snd3 :: (a, b, c) -> b
snd3 (_, b, _) = b

third3 :: (a, b, c) -> c
third3 (_, _, c) = c

-- | Drop the first component of a triple.
tail3 :: (t, t1, t2) -> (t1, t2)
tail3 (_, j, x) = (j, x)

mapFst3 :: (a -> b) -> (a, y, z) -> (b, y, z)
mapFst3 f (a, b, c) = (f a, b, c)

mapSnd3 :: (a -> b) -> (x, a, z) -> (x, b, z)
mapSnd3 f (a, b, c) = (a, f b, c)

mapThird3 :: (a -> b) -> (x, y, a) -> (x, y, b)
mapThird3 f (a, b, c) = (a, b, f c)

-- | Apply the projection @p@ to both arguments, then combine the results
-- with @f@. Equivalent to @f `on` p@ from "Data.Function".
lift2 :: (a -> b) -> (b -> b -> c) -> a -> a -> c
lift2 p f t1 t2 = f (p t1) (p t2)


-- * Stream-based merging
--
-- A stream-fusion-style merge of sorted @(index, value)@ streams, following
-- [1].
--
-- [1] : https://www.schoolofhaskell.com/user/edwardk/revisiting-matrix-multiplication/part-3

-- | A stream is a step function over a hidden seed type @s@, together with
-- its initial seed.
data Stream m a = forall s . Stream (s -> m (Step s a)) s

-- | One step of a 'Stream': emit an element, advance without emitting, or
-- stop.
data Step s a = Yield a s | Skip s | Done

-- | Internal state of 'mergeStreamsWith'.
data MergeState sa sb i a
  = MergeL sa sb i a   -- ^ holding an element taken from the left stream
  | MergeR sa sb i a   -- ^ holding an element taken from the right stream
  | MergeLeftEnded sb  -- ^ left stream exhausted; drain the right one
  | MergeRightEnded sa -- ^ right stream exhausted; drain the left one
  | MergeStart sa sb   -- ^ holding nothing; pull from the left stream next

-- | Merge two streams of @(index, value)@ pairs into one stream.
--
-- Both inputs are expected to be sorted by increasing index; the output is
-- then sorted by index too. When both streams hold the same index, the two
-- values are combined with @f@. @Just c@ emits @(index, c)@, and @Nothing@
-- drops that index from the output altogether.
mergeStreamsWith :: (Ord i, Monad m) =>
  (a -> a -> Maybe a) -> Stream m (i, a) -> Stream m (i, a) -> Stream m (i, a)
mergeStreamsWith f (Stream stepa sa0) (Stream stepb sb0) =
  Stream step (MergeStart sa0 sb0)
  where
    -- nothing held: fetch the next left element
    step (MergeStart sa sb) = do
      r <- stepa sa
      return $ case r of
        Yield (i, a) sa' -> Skip (MergeL sa' sb i a)
        Skip sa'         -> Skip (MergeStart sa' sb)
        Done             -> Skip (MergeLeftEnded sb)
    -- holding left element (i, a): fetch a right element and compare
    step (MergeL sa sb i a) = do
      r <- stepb sb
      return $ case r of
        Yield (j, b) sb' -> case compare i j of
          LT -> Yield (i, a) (MergeR sa sb' j b)
          EQ -> case f a b of
            Just c  -> Yield (i, c) (MergeStart sa sb')
            Nothing -> Skip (MergeStart sa sb')
          GT -> Yield (j, b) (MergeL sa sb' i a)
        Skip sb' -> Skip (MergeL sa sb' i a)
        Done     -> Yield (i, a) (MergeRightEnded sa)
    -- holding right element (j, b): fetch a left element and compare
    step (MergeR sa sb j b) = do
      r <- stepa sa
      return $ case r of
        Yield (i, a) sa' -> case compare i j of
          LT -> Yield (i, a) (MergeR sa' sb j b)
          EQ -> case f a b of
            Just c  -> Yield (i, c) (MergeStart sa' sb)
            Nothing -> Skip (MergeStart sa' sb)
          GT -> Yield (j, b) (MergeL sa' sb i a)
        Skip sa' -> Skip (MergeR sa' sb j b)
        Done     -> Yield (j, b) (MergeLeftEnded sb)
    -- left exhausted: pass the right stream through
    step (MergeLeftEnded sb) = do
      r <- stepb sb
      return $ case r of
        Yield (j, b) sb' -> Yield (j, b) (MergeLeftEnded sb')
        Skip sb'         -> Skip (MergeLeftEnded sb')
        Done             -> Done
    -- right exhausted: pass the left stream through
    step (MergeRightEnded sa) = do
      r <- stepa sa
      return $ case r of
        Yield (i, a) sa' -> Yield (i, a) (MergeRightEnded sa')
        Skip sa'         -> Skip (MergeRightEnded sa')
        Done             -> Done
    {-# INLINE [0] step #-}
{-# INLINE [1] mergeStreamsWith #-}


-- test data
-- m0 = V.fromList [O (0,0,1), O(0,1,2)]
-- m1 = V.fromList [O (0,0,1), O(0,2,3)]

-- | @True@ iff the elements are in strictly increasing order. Empty and
-- singleton vectors count as ordered.
isOrderedV :: Ord a => V.Vector a -> Bool
isOrderedV l = V.and (V.zipWith (<) l (V.drop 1 l))