#include "fusion-phases.h"
module Data.Array.Parallel.Unlifted.Stream.Segments
( streamSegsFromNestedUSSegd
, streamSegsFromVectorsUSSegd
, streamSegsFromVectorsUVSegd)
where
import Data.Vector.Fusion.Stream.Size
import Data.Vector.Fusion.Stream.Monadic
import Data.Array.Parallel.Unlifted.Sequential.Vector (Unbox, Vector, index)
import Data.Array.Parallel.Unlifted.Vectors (Unboxes, Vectors)
import Data.Array.Parallel.Unlifted.Sequential.USSegd (USSegd(..))
import Data.Array.Parallel.Unlifted.Sequential.UVSegd (UVSegd(..))
import qualified Data.Array.Parallel.Unlifted.Vectors as US
import qualified Data.Array.Parallel.Unlifted.Sequential.USegd as USegd
import qualified Data.Array.Parallel.Unlifted.Sequential.USSegd as USSegd
import qualified Data.Array.Parallel.Unlifted.Sequential.UVSegd as UVSegd
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector as V
import qualified Data.Primitive.ByteArray as P
import System.IO.Unsafe
-- | Stream the elements of the physical segments described by a `USSegd`,
--   reading them out of a boxed vector of source arrays.
--
--   Segments are visited in descriptor order; for every segment the source
--   array is selected by the descriptor's source id and the elements are read
--   starting from the descriptor's start offset.
--
--   The resulting stream carries no size hint (`Unknown`).
streamSegsFromNestedUSSegd
        :: (Unbox a, Monad m)
        => V.Vector (Vector a)  -- ^ Source arrays.
        -> USSegd               -- ^ Descriptor selecting slices of the source arrays.
        -> Stream m a
streamSegsFromNestedUSSegd pdatas ussegd@(USSegd _ starts sources usegd)
 = Stream step (0, 0) Unknown
 where
        -- Location tag handed to the `index` calls.
        here    = "streamSegsFromNestedUSSegd"

        -- Length of every segment.
        segLens = USegd.takeLengths usegd

        -- The state is (current segment, position inside that segment).
        step (segIx, elemIx)
         = if segIx >= USSegd.length ussegd
            -- Ran past the last segment: nothing more to stream.
            then return Done
            else if elemIx >= U.unsafeIndex segLens segIx
                  -- Current segment exhausted: restart at the next one.
                  then return (Skip (segIx + 1, 0))
                  -- Emit the next element of the current segment.
                  else let !srcId  = index here sources segIx
                           !chunk  = V.unsafeIndex pdatas srcId
                           !offset = index here starts segIx
                           !x      = index here chunk (offset + elemIx)
                       in  return (Yield x (segIx, elemIx + 1))
-- | Stream the segments described by a scattered segment descriptor out of
--   a `Vectors`, element by element, in segment order.
--
--   The loop state is @(ixSeg, baSeg, ixEnd, ixElem)@:
--
--   * @ixSeg@  - index of the segment currently being streamed,
--
--   * @baSeg@  - byte array holding the data of that segment,
--
--   * @ixEnd@  - index into @baSeg@ one past the last element of the segment,
--
--   * @ixElem@ - index into @baSeg@ of the next element to yield.
--
--   The stream is given an exact size hint equal to the total number of
--   elements in the descriptor. A descriptor with no segments produces an
--   empty stream.
streamSegsFromVectorsUSSegd
        :: (Unboxes a, Monad m)
        => Vectors a            -- ^ Vectors holding the source data.
        -> USSegd               -- ^ Scattered segment descriptor.
        -> Stream m a
streamSegsFromVectorsUSSegd
        vectors
        ussegd@(USSegd _ segStarts segSources usegd)
 = segStarts `seq` segSources `seq` usegd `seq` vectors `seq`
   let  -- Location tag handed to the `index` calls.
        here            = "streamSegsFromVectorsUSSegd"

        -- Length of each segment.
        !segLens        = USegd.takeLengths usegd

        -- Total number of segments.
        !segsTotal      = USSegd.length ussegd

        -- Total number of elements to stream.
        !elements       = USegd.takeElements usegd

        fnSeg (ixSeg, baSeg, ixEnd, ixElem)
         = ixSeg `seq` baSeg `seq`
           if ixElem >= ixEnd
            -- The current segment is exhausted.
            then if ixSeg + 1 >= segsTotal
                  -- That was the last segment.
                  then return $ Done

                  -- Load the next segment: look up its source vector and
                  -- turn the segment's start and length into absolute
                  -- indices into that vector's byte array.
                  else let ixSeg'       = ixSeg + 1
                           sourceSeg    = index here segSources ixSeg'
                           startSeg     = index here segStarts  ixSeg'
                           lenSeg       = index here segLens    ixSeg'
                           (arr, startArr, _)
                                        = US.unsafeIndexUnpack vectors sourceSeg
                       in  return $ Skip
                                  ( ixSeg'
                                  , arr
                                  , startArr + startSeg + lenSeg
                                  , startArr + startSeg)

            -- Yield the next element of the current segment.
            else let !result = P.indexByteArray baSeg ixElem
                 in  return $ Yield result (ixSeg, baSeg, ixEnd, ixElem + 1)

        -- Placeholder byte array for the initial state. Using
        -- unsafePerformIO is harmless here: it only allocates and freezes an
        -- empty array, and the array is never read because the initial state
        -- has ixElem >= ixEnd, which sends the first step to the
        -- "load next segment" branch (or straight to Done).
        !dummy  = unsafePerformIO
                $ P.newByteArray 0 >>= P.unsafeFreezeByteArray

        -- The segment index must start at -1: the first step loads segment
        -- (ixSeg + 1), so any other starting value would skip leading
        -- segments while the size hint below still counts their elements.
        !initState
         =      ( -1    -- first step loads segment 0
                , dummy -- placeholder data, never indexed
                , 0     -- ixEnd  = 0 \ empty range forces the
                , 0)    -- ixElem = 0 / first step to load a segment

   in   Stream fnSeg initState (Exact elements)
-- | Stream the virtual segments described by a `UVSegd` out of a `Vectors`,
--   element by element, in virtual-segment order.
--
--   Each virtual segment index is first mapped through the descriptor's
--   vsegids to a physical segment index, which is then used to look up the
--   source vector, start offset and length of the data to stream.
--
--   The loop state is @(ixSeg, baSeg, ixEnd, ixElem)@:
--
--   * @ixSeg@  - index of the virtual segment currently being streamed,
--
--   * @baSeg@  - byte array holding the data of that segment,
--
--   * @ixEnd@  - index into @baSeg@ one past the last element of the segment,
--
--   * @ixElem@ - index into @baSeg@ of the next element to yield.
--
--   The stream is given an exact size hint equal to the sum of the virtual
--   segment lengths. A descriptor with no segments produces an empty stream.
streamSegsFromVectorsUVSegd
        :: (Unboxes a, Monad m)
        => Vectors a            -- ^ Vectors holding the source data.
        -> UVSegd               -- ^ Virtual segment descriptor.
        -> Stream m a
streamSegsFromVectorsUVSegd
        vectors
        uvsegd@(UVSegd _ _ vsegids _ (USSegd _ segStarts segSources usegd))
 = segStarts `seq` segSources `seq` uvsegd `seq` vectors `seq`
   let  -- Location tag handed to the `index` calls.
        here            = "streamSegsFromVectorsUVSegd"

        -- Total number of elements to stream.
        !elemsTotal     = U.sum $ UVSegd.takeLengths uvsegd

        -- Total number of virtual segments.
        !segsTotal      = UVSegd.length uvsegd

        -- Length of each physical segment.
        !segLens        = USegd.takeLengths usegd

        fnSeg (ixSeg, baSeg, ixEnd, ixElem)
         = ixSeg `seq` baSeg `seq`
           if ixElem >= ixEnd
            -- The current segment is exhausted.
            then if ixSeg + 1 >= segsTotal
                  -- That was the last virtual segment.
                  then return $ Done

                  -- Load the next virtual segment: resolve it to a physical
                  -- segment, look up that segment's source vector, and turn
                  -- its start and length into absolute indices into the
                  -- vector's byte array.
                  else let ixSeg'       = ixSeg + 1
                           ixPSeg       = index here vsegids    ixSeg'
                           sourceSeg    = index here segSources ixPSeg
                           startSeg     = index here segStarts  ixPSeg
                           lenSeg       = index here segLens    ixPSeg
                           (arr, startArr, _)
                                        = US.unsafeIndexUnpack vectors sourceSeg
                       in  return $ Skip
                                  ( ixSeg'
                                  , arr
                                  , startArr + startSeg + lenSeg
                                  , startArr + startSeg)

            -- Yield the next element of the current segment.
            else let !result = P.indexByteArray baSeg ixElem
                 in  return $ Yield result (ixSeg, baSeg, ixEnd, ixElem + 1)

        -- Placeholder byte array for the initial state. Using
        -- unsafePerformIO is harmless here: it only allocates and freezes an
        -- empty array, and the array is never read because the initial state
        -- has ixElem >= ixEnd, which sends the first step to the
        -- "load next segment" branch (or straight to Done).
        !dummy  = unsafePerformIO
                $ P.newByteArray 0 >>= P.unsafeFreezeByteArray

        -- The segment index must start at -1: the first step loads virtual
        -- segment (ixSeg + 1), so any other starting value would skip leading
        -- segments while the size hint below still counts their elements.
        !initState
         =      ( -1    -- first step loads virtual segment 0
                , dummy -- placeholder data, never indexed
                , 0     -- ixEnd  = 0 \ empty range forces the
                , 0)    -- ixElem = 0 / first step to load a segment

   in   Stream fnSeg initState (Exact elemsTotal)