module Data.Repa.Chain.Folds (foldsC, Folds (..)) where import Data.Repa.Chain.Base import Data.Repa.Scalar.Option import Data.Vector.Fusion.Stream.Size as S #include "repa-stream.h" -- | Segmented fold over vectors of segment lengths and input values. -- -- The total lengths of all segments need not match the length of the -- input elements vector. The returned `C.Folds` state can be inspected -- to determine whether all segments were completely folded, or the -- vector of segment lengths or elements was too short relative to the -- other. -- foldsC :: Monad m => (a -> b -> m b) -- ^ Worker function. -> b -- ^ Initial state when folding rest of segments. -> Option3 n Int b -- ^ Name, length and initial state for first segment. -> Chain m sLen (n, Int) -- ^ Segment names and lengths. -> Chain m sVal a -- ^ Input data to fold. -> Chain m (Folds sLen sVal n a b) (n, b) foldsC f zN s0 (Chain _szLens sLens0 stepLens) (Chain _szVals sVals0 stepVals) = Chain S.Unknown (init_foldsC s0) step where init_foldsC s = case s of None3 -> Folds sLens0 sVals0 None 0 zN Some3 n len acc -> Folds sLens0 sVals0 (Some n) len acc {-# NOINLINE init_foldsC #-} -- NOINLINE to hide the case match from the simplifier so it -- doesn't unswitch it at top-level and duplicate the follow-on code. step ss@(Folds sLens sVals nameSeg lenSeg valSeg) = case nameSeg of -- If we don't have a segment length we need to load the next one. None -> stepLens sLens >>= \rLens -> case rLens of -- We got a segment length, so load it into the state and -- initialise the accumulator. Yield (name, xLen) sLens' -> return $ Skip ss { _stateLens = sLens' , _nameSeg = Some name , _lenSeg = xLen , _valSeg = zN } -- Lengths input takes a step. Skip sLens' -> return $ Skip ss { _stateLens = sLens' } -- We're not currently folding a segment, and no more segment -- lengths are available, so we're done. Done sLens' -> return $ Done ss { _stateLens = sLens' } -- We're currently folding a segment. Some name -- We've reached the end of the segment, so emit the result. | lenSeg == 0 -> return $ Yield (name, valSeg) ss { _nameSeg = None } -- We still need more values for this segment. | otherwise -> stepVals sVals >>= \rVals -> case rVals of -- We got a new value, so accumulate it into the state. Yield xVal sVals' -> f xVal valSeg >>= \rAcc -> return $ Skip ss { _stateVals = sVals' , _lenSeg = lenSeg - 1 , _valSeg = rAcc } -- Vals input takes a step. Skip sVals' -> return $ Skip ss { _stateVals = sVals' } -- We're in a non-zero lengthed segment, but haven't got -- all the values, so we're done for now. Done sVals' -> return $ Done ss { _stateVals = sVals' } {-# INLINE_INNER step #-} {-# INLINE_STREAM foldsC #-} -- | Return state of a folds operation. data Folds sLens sVals n a b = Folds { -- | State of lengths chain. _stateLens :: !sLens -- | State of values chain. , _stateVals :: !sVals -- | If we're currently in a segment, then hold its name, , _nameSeg :: !(Option n) -- | Length of current segment. , _lenSeg :: !Int -- | Accumulated value of current segment. , _valSeg :: !b } deriving Show {- -- Defining folds in terms of weave doesn't work because if all the -- segment lengths are 0 then we don't want to load any values at all. = weaveC work s0 cLens cVals where work !ms !mxLen !mxVal = case ms of -- If we haven't got a current state then load the next -- segment length. None2 -> case mxLen of None -> return $ Finish ms MoveNone Some xLen -> return $ Next (Some2 xLen zN) MoveLeft Some2 len acc | len == 0 -> return $ Give acc None2 MoveNone | otherwise -> case mxVal of None -> return $ Finish ms MoveNone Some xVal -> do r <- f xVal acc return $ Next (Some2 (len - 1) r) MoveRight {-# INLINE [1] work #-} -- | Pack the weave state of a folds operation into a `Folds` record, -- which has better field names. packFolds :: Weave sLens Int sVals a (Option2 Int b) -> Folds sLens sVals a b packFolds (Weave stateL elemL _endL stateR elemR _endR mLenAcc) = (Folds stateL elemL stateR elemR mLenAcc) {-# INLINE packFolds #-} -}