module Data.Repa.Stream.Ratchet
        ( unsafeRatchetS)
where
import Control.Monad.Primitive
import Data.IORef
import Data.Vector.Fusion.Stream.Monadic         (Stream(..), Step(..))
import qualified Data.Vector.Generic            as GV
import qualified Data.Vector.Generic.Mutable    as GM
import qualified Data.Vector.Fusion.Stream.Size as S
#include "repa-stream.h"


-- | Interleaved `enumFromTo`.
--
--   Given a vector of starting values, and a vector of stopping values,
--   produce a stream of elements where we increase each of the starting
--   values to the stopping values in a round-robin order. Also produce a
--   vector of result segment lengths.
--
--   Stopping values are exclusive: a segment whose current value is
--   greater than or equal to its stopping value contributes no element.
--
-- @
--  unsafeRatchetS [10,20,30,40] [15,26,33,47]
--  =  [10,20,30,40       -- 4
--     ,11,21,31,41       -- 4
--     ,12,22,32,42       -- 4
--     ,13,23   ,43       -- 3
--     ,14,24   ,44       -- 3
--        ,25   ,45       -- 2
--              ,46]      -- 1
--
--         ^^^^             ^^^
--       Elements         Lengths
-- @
--
--   The function takes the starting values in a mutable vector and
--   updates it during computation. Computation proceeds by making passes
--   through the mutable vector and updating the starting values until
--   they match the stopping values.
--
--   The lengths vector is read from the `IORef` the first time a segment
--   length needs to be recorded, and is grown whenever it is full. Each
--   time it is grown the new vector is written back to the `IORef`. When
--   the stream is exhausted, the `IORef` is overwritten with a slice that
--   holds exactly the lengths that were written.
--
--   UNSAFE: Both input vectors must have the same length,
--           but this is not checked.
--
unsafeRatchetS
        :: (GM.MVector vm Int, GV.Vector vv Int)
        => vm (PrimState IO) Int          -- ^ Starting values. Overwritten during computation.
        -> vv Int                         -- ^ Ending values (exclusive).
        -> IORef (vm (PrimState IO) Int)  -- ^ Vector holding segment lengths.
                                          --   Replaced as the vector is grown, and
                                          --   trimmed to the number of lengths
                                          --   written once the stream is done.
        -> Stream IO Int

unsafeRatchetS !mvStarts !vMax !rmvLens
 = Stream ostep (0, Nothing, 0, 0) S.Unknown
 where
        -- Index of the last input segment.
        !iSegMax = GM.length mvStarts - 1

        -- Stream state is (iSeg, mvmLens, oSeg, oLen):
        --   iSeg    - index of the input segment to visit next in this pass.
        --   mvmLens - lengths vector, once it has been read from the IORef.
        --   oSeg    - index of the next slot to write in the lengths vector.
        --   oLen    - number of elements yielded so far in the current pass.
        ostep (iSeg, mvmLens, oSeg, oLen)
         = ostep' iSeg mvmLens oSeg oLen
        {-# INLINE ostep #-}

        -- Get the lengths vector, reading it from the IORef only if we
        -- are not already carrying it in the stream state.
        currentLens Nothing       = readIORef rmvLens
        currentLens (Just vmLens) = return vmLens
        {-# INLINE currentLens #-}

        ostep' !iSeg !mvmLens !oSeg !oLen
         -- Still inside the current pass over the input segments.
         | iSeg <= iSegMax
         = do   !iVal      <- GM.unsafeRead mvStarts iSeg
                let !iNext = vMax `GV.unsafeIndex` iSeg
                if iVal >= iNext
                 -- This segment is already finished, move on.
                 then   return $ Skip (iSeg + 1, mvmLens, oSeg, oLen)
                 -- Bump the stored value and emit the old one.
                 else do
                        GM.unsafeWrite mvStarts iSeg (iVal + 1)
                        return $ Yield iVal (iSeg + 1, mvmLens, oSeg, oLen + 1)

         -- We're at the end of an output segment,
         -- so write the output length into the lengths vector.
         | oLen > 0
         = do   !vmLens      <- currentLens mvmLens

                -- If the output vector is full then we need to grow it.
                let !oSegLen = GM.length vmLens
                if oSeg >= oSegLen
                 then do
                        -- Grow by at least one element. Growing a zero-length
                        -- vector by its own length would leave it empty, and
                        -- the write below would then be out of bounds.
                        !vmLens' <- GM.unsafeGrow vmLens (max 1 oSegLen)
                        writeIORef rmvLens vmLens'
                        GM.unsafeWrite vmLens' oSeg oLen
                        return $ Skip (0, Just vmLens', oSeg + 1, 0)

                 else do
                        GM.unsafeWrite vmLens oSeg oLen
                        return $ Skip (0, Just vmLens,  oSeg + 1, 0)

         -- A whole pass yielded nothing, so every segment is finished.
         -- Trim the lengths vector to the part we wrote and publish it.
         | otherwise
         = do   !vmLens      <- currentLens mvmLens
                let !vmLens' = GM.unsafeSlice 0 oSeg vmLens
                writeIORef rmvLens vmLens'
                return Done
        {-# INLINE_INNER ostep' #-}
{-# INLINE_STREAM unsafeRatchetS #-}