module Futhark.CodeGen.ImpGen.Multicore.SegScan
  ( compileSegScan,
  )
where

import Control.Monad
import Data.List (zip4)
import Futhark.CodeGen.ImpCode.Multicore qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.IR.MCMem
import Futhark.Util.IntegralExp (quot, rem)
import Prelude hiding (quot, rem)

-- | Compile a SegScan construct.
--
-- When the segment space has exactly one dimension the scan is
-- handed to 'nonsegmentedScan'; any other space goes to
-- 'segmentedScan', which does not take the subtask count.
compileSegScan ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.MCCode
compileSegScan pat space reds kbody nsubtasks
  | [_] <- unSegSpace space =
      nonsegmentedScan pat space reds kbody nsubtasks
  | otherwise =
      segmentedScan pat space reds kbody

-- | Split the parameters of a scan operator's lambda in two.
--
-- 'xParams' is the leading group (as many parameters as the operator
-- has neutral elements) and 'yParams' is whatever follows.
xParams, yParams :: SegBinOp MCMem -> [LParam MCMem]
xParams scan =
  take (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
yParams scan =
  drop (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))

-- | The body of the lambda belonging to a scan operator.
lamBody :: SegBinOp MCMem -> Body MCMem
lamBody = lambdaBody . segBinOpLambda

-- | Arrays for storing worker results.
--
-- For every operator, one array is allocated per lambda return type.
-- Each array is named after the given string, lives in
-- 'DefaultSpace', and has the shape
-- @[nsubtasks] <> operator shape <> shape of the return type@.
-- The result holds one list of array names per operator.
carryArrays :: String -> TV Int32 -> [SegBinOp MCMem] -> MulticoreGen [[VName]]
carryArrays s nsubtasks segops =
  forM segops $ \(SegBinOp _ lam _ shape) ->
    forM (lambdaReturnType lam) $ \t -> do
      let pt = elemType t
          -- The outermost dimension is indexed by subtask.
          full_shape =
            Shape [Var (tvVar nsubtasks)]
              <> shape
              <> arrayShape t
      sAllocArray s pt full_shape DefaultSpace

-- | Generate code for a scan over a one-dimensional space.
--
-- A debug print is emitted directly; everything else is collected
-- and returned as the resulting code.  Stage 1 is always generated.
-- Stages 2 and 3 are generated under an 'sWhen' guard that tests
-- whether the number of subtasks exceeds one, and each of them works
-- on a freshly renamed copy of the scan operators.
nonsegmentedScan ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.MCCode
nonsegmentedScan pat space scan_ops kbody nsubtasks = do
  emit $ Imp.DebugPrint "nonsegmented segScan" Nothing
  collect $ do
    -- Are we working with nested arrays
    let dims = map (shapeDims . segBinOpShape) scan_ops
    -- Are we only working on scalars
    let scalars =
          all (all (primType . typeOf . paramDec) . (lambdaParams . segBinOpLambda)) scan_ops
            && all null dims
    -- Do we have nested vector operations
    let vectorize = [] `notElem` dims

    let param_types = concatMap (map paramType . (lambdaParams . segBinOpLambda)) scan_ops
    let no_array_param = all primType param_types

    -- Pick matching stage 1 / stage 3 generators.  The guards are
    -- tried in order, so the scalar variant wins whenever it applies.
    let (scanStage1, scanStage3)
          | scalars = (scanStage1Scalar, scanStage3Scalar)
          | vectorize && no_array_param = (scanStage1Nested, scanStage3Nested)
          | otherwise = (scanStage1Fallback, scanStage3Fallback)

    emit $ Imp.DebugPrint "Scan stage 1" Nothing
    scanStage1 pat space kbody scan_ops

    let nsubtasks' = tvExp nsubtasks
    sWhen (nsubtasks' .>. 1) $ do
      scan_ops2 <- renameSegBinOp scan_ops
      emit $ Imp.DebugPrint "Scan stage 2" Nothing
      carries <- scanStage2 pat nsubtasks space scan_ops2
      scan_ops3 <- renameSegBinOp scan_ops
      emit $ Imp.DebugPrint "Scan stage 3" Nothing
      scanStage3 pat space scan_ops3 carries

-- | The different ways in which code for a scan loop can be
-- generated.  The constructor selects the loop wrapper
-- ('getScanLoop'), the element extraction ('getExtract') and the
-- loop nest ('getNestLoop') used when applying the scan operators.
data ScanLoopType
  = ScanSeq -- ^ Fully sequential
  | ScanNested -- ^ Nested vectorized map
  | ScanScalar -- ^ Vectorized scan over scalars

-- | Given a scan type, return a function to inject into the loop body.
--
-- For 'ScanScalar' the body is handed to 'generateUniformizeLoop'.
-- For every other scan type the body is simply run once with index 0.
getScanLoop ::
  ScanLoopType ->
  (Imp.TExp Int64 -> MulticoreGen ()) ->
  MulticoreGen ()
getScanLoop ScanScalar = generateUniformizeLoop
getScanLoop _ = \body -> body 0

-- | Given a scan type, return a function to extract a scalar from a
-- vector.
--
-- For 'ScanSeq' the index is ignored: the action is run and the code
-- it yields is emitted unchanged.  Every other scan type delegates
-- to 'extractVectorLane'.
getExtract :: ScanLoopType -> Imp.TExp Int64 -> MulticoreGen Imp.MCCode -> MulticoreGen ()
getExtract ScanSeq = \_ body -> body >>= emit
getExtract _ = extractVectorLane

-- | Declare the lambda parameters of all the given scan operators by
-- bringing their scope into the code generator via 'dScope'.
genBinOpParams :: [SegBinOp MCMem] -> MulticoreGen ()
genBinOpParams scan_ops =
  dScope Nothing $
    scopeOfLParams $
      concatMap (lambdaParams . segBinOpLambda) scan_ops

-- | Create the stage 1 accumulators and initialise them to the
-- neutral elements.
--
-- One accumulator is produced for each triple of x-parameter,
-- neutral element and lambda return type of an operator.  If the
-- operator shape has no dimensions, the x-parameter itself serves as
-- the accumulator; otherwise a fresh @"local_acc"@ array of shape
-- @operator shape <> shape of the return type@ is allocated.  The
-- result holds one list of accumulator names per operator.
genLocalAccsStage1 :: [SegBinOp MCMem] -> MulticoreGen [[VName]]
genLocalAccsStage1 scan_ops =
  forM scan_ops $ \scan_op -> do
    let shape = segBinOpShape scan_op
        ts = lambdaReturnType $ segBinOpLambda scan_op
    forM (zip3 (xParams scan_op) (segBinOpNeutral scan_op) ts) $ \(p, ne, t) -> do
      -- update accumulator to have type decoration
      acc <-
        case shapeDims shape of
          [] -> pure $ paramName p
          _ -> do
            let pt = elemType t
            sAllocArray "local_acc" pt (shape <> arrayShape t) DefaultSpace

      -- Now neutral-initialise the accumulator.
      sLoopNest shape $ \vec_is ->
        copyDWIMFix acc vec_is ne []

      pure acc

-- | Given a scan type, return the loop-nest generator used to
-- traverse an operator shape: 'sLoopNestVectorized' for
-- 'ScanNested', plain 'sLoopNest' for everything else.
getNestLoop ::
  ScanLoopType ->
  Shape ->
  ([Imp.TExp Int64] -> MulticoreGen ()) ->
  MulticoreGen ()
getNestLoop ScanNested = sLoopNestVectorized
getNestLoop _ = sLoopNest

-- | Apply each scan operator to its accumulator and the next values.
--
-- The scan results and the pattern elements are first chunked per
-- operator.  Then, for each operator and each index of its shape:
--
--   1. the accumulators are copied into the x-parameters,
--   2. the next values are copied into the y-parameters (through
--      'getExtract'),
--   3. the operator body is compiled, and each of its results is
--      written both to the pattern element (indexed by the space
--      indices followed by the shape indices) and back into the
--      accumulator.
--
-- The scan type selects the wrappers used around these steps; see
-- 'getScanLoop', 'getNestLoop' and 'getExtract'.
applyScanOps ::
  ScanLoopType ->
  Pat LetDecMem ->
  SegSpace ->
  [SubExp] ->
  [SegBinOp MCMem] ->
  [[VName]] ->
  ImpM MCMem HostEnv Imp.Multicore ()
applyScanOps typ pat space all_scan_res scan_ops local_accs = do
  let per_scan_res = segBinOpChunks scan_ops all_scan_res
      per_scan_pes = segBinOpChunks scan_ops $ patElems pat
  let (is, _) = unzip $ unSegSpace space

  -- Potential vector load and then do sequential scan
  getScanLoop typ $ \j ->
    forM_ (zip4 per_scan_pes scan_ops per_scan_res local_accs) $ \(pes, scan_op, scan_res, acc) ->
      getNestLoop typ (segBinOpShape scan_op) $ \vec_is -> do
        sComment "Read accumulator" $
          forM_ (zip (xParams scan_op) acc) $ \(p, acc') ->
            copyDWIMFix (paramName p) [] (Var acc') vec_is
        sComment "Read next values" $
          forM_ (zip (yParams scan_op) scan_res) $ \(p, se) ->
            getExtract typ j $
              collect $
                copyDWIMFix (paramName p) [] se vec_is
        -- Scan body
        sComment "Scan op body" $
          compileStms mempty (bodyStms $ lamBody scan_op) $
            forM_ (zip3 acc pes $ map resSubExp $ bodyResult $ lamBody scan_op) $
              \(acc', pe, se) -> do
                -- Write the result to the output and carry it
                -- forward in the accumulator.
                copyDWIMFix (patElemName pe) (map Imp.le64 is ++ vec_is) se []
                copyDWIMFix acc' vec_is se []

-- | Generate a loop which performs a potentially vectorized scan on
-- the result of a kernel body.
--
-- The flat index @i@ is unflattened over the dimensions of the space
-- and bound to the space's index variables.  The kernel body
-- statements are then compiled.  Its leading @segBinOpResults@
-- results are fed to 'applyScanOps'; the remaining (mapped) results
-- are written to the correspondingly positioned pattern elements
-- with 'compileThreadResult'.
genScanLoop ::
  ScanLoopType ->
  Pat LetDecMem ->
  SegSpace ->
  KernelBody MCMem ->
  [SegBinOp MCMem] ->
  [[VName]] ->
  Imp.TExp Int64 ->
  ImpM MCMem HostEnv Imp.Multicore ()
genScanLoop typ pat space kbody scan_ops local_accs i = do
  let (all_scan_res, map_res) =
        splitAt (segBinOpResults scan_ops) $ kernelBodyResult kbody
  let (is, ns) = unzip $ unSegSpace space
      ns' = map pe64 ns

  zipWithM_ dPrimV_ is $ unflattenIndex ns' i
  compileStms mempty (kernelBodyStms kbody) $ do
    let map_arrs = drop (segBinOpResults scan_ops) $ patElems pat
    sComment "write mapped values results to memory" $
      zipWithM_ (compileThreadResult space) map_arrs map_res
    sComment "Apply scan op" $
      applyScanOps typ pat space (map kernelResultSubExp all_scan_res) scan_ops local_accs

-- | Stage 1 of the scan, using the 'ScanScalar' loop type.
--
-- The loop body declares the flat index of the space, fetches the
-- task id into it, declares the operator parameters and creates the
-- local accumulators.  The chunk loop itself is generated with
-- 'Vectorized' inside 'inISPC'.  The collected body is emitted as a
-- @"scan_stage_1"@ parallel loop together with its free parameters.
scanStage1Scalar ::
  Pat LetDecMem ->
  SegSpace ->
  KernelBody MCMem ->
  [SegBinOp MCMem] ->
  MulticoreGen ()
scanStage1Scalar pat space kbody scan_ops = do
  fbody <- collect $ do
    dPrim_ (segFlat space) int64
    sOp $ Imp.GetTaskId (segFlat space)

    genBinOpParams scan_ops
    local_accs <- genLocalAccsStage1 scan_ops
    inISPC $
      generateChunkLoop "SegScan" Vectorized $
        genScanLoop ScanScalar pat space kbody scan_ops local_accs
  free_params <- freeParams fbody
  emit $ Imp.Op $ Imp.ParLoop "scan_stage_1" fbody free_params

-- | Stage 1 of the scan, using the 'ScanNested' loop type.
--
-- Unlike the scalar and fallback variants, the local accumulators
-- are created before the operator parameters are declared, and the
-- parameter declarations happen inside 'inISPC' together with a
-- 'Scalar' chunk loop.  The collected body is emitted as a
-- @"scan_stage_1"@ parallel loop together with its free parameters.
scanStage1Nested ::
  Pat LetDecMem ->
  SegSpace ->
  KernelBody MCMem ->
  [SegBinOp MCMem] ->
  MulticoreGen ()
scanStage1Nested pat space kbody scan_ops = do
  fbody <- collect $ do
    dPrim_ (segFlat space) int64
    sOp $ Imp.GetTaskId (segFlat space)

    local_accs <- genLocalAccsStage1 scan_ops

    inISPC $ do
      genBinOpParams scan_ops
      generateChunkLoop "SegScan" Scalar $ \i ->
        genScanLoop ScanNested pat space kbody scan_ops local_accs i

  free_params <- freeParams fbody
  emit $ Imp.Op $ Imp.ParLoop "scan_stage_1" fbody free_params

-- | Stage 1 of the scan, using the fully sequential 'ScanSeq' loop
-- type.  This variant does not go through 'inISPC'.
--
-- The collected body is emitted as a @"scan_stage_1"@ parallel loop
-- together with its free parameters.
scanStage1Fallback ::
  Pat LetDecMem ->
  SegSpace ->
  KernelBody MCMem ->
  [SegBinOp MCMem] ->
  MulticoreGen ()
scanStage1Fallback pat space kbody scan_ops = do
  -- Stage 1 : each thread partially scans a chunk of the input
  -- Writes directly to the resulting array
  fbody <- collect $ do
    dPrim_ (segFlat space) int64
    sOp $ Imp.GetTaskId (segFlat space)

    genBinOpParams scan_ops
    local_accs <- genLocalAccsStage1 scan_ops

    generateChunkLoop "SegScan" Scalar $
      genScanLoop ScanSeq pat space kbody scan_ops local_accs
  free_params <- freeParams fbody
  emit $ Imp.Op $ Imp.ParLoop "scan_stage_1" fbody free_params

scanStage2 ::
  Pat LetDecMem ->
  TV Int32 ->
  SegSpace ->
  [SegBinOp MCMem] ->
  MulticoreGen [[VName]]
scanStage2 :: Pat LParamMem
-> TV Int32
-> SegSpace
-> [SegBinOp MCMem]
-> MulticoreGen [[VName]]
scanStage2 Pat LParamMem
pat TV Int32
nsubtasks SegSpace
space [SegBinOp MCMem]
scan_ops = do
  let ([VName]
is, [SubExp]
ns) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      ns_64 :: [TExp Int64]
ns_64 = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
ns
      per_scan_pes :: [[PatElem LParamMem]]
per_scan_pes = forall rep a. [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks [SegBinOp MCMem]
scan_ops forall a b. (a -> b) -> a -> b
$ forall dec. Pat dec -> [PatElem dec]
patElems Pat LParamMem
pat
      nsubtasks' :: TExp Int64
nsubtasks' = forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
nsubtasks

  forall rep (inner :: * -> *) r op.
Mem rep inner =>
Maybe (Exp rep) -> Scope rep -> ImpM rep r op ()
dScope forall a. Maybe a
Nothing forall a b. (a -> b) -> a -> b
$ forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (forall rep. Lambda rep -> [LParam rep]
lambdaParams forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda) [SegBinOp MCMem]
scan_ops
  TV Int64
offset <- forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"offset" (TExp Int64
0 :: Imp.TExp Int64)
  let offset' :: TExp Int64
offset' = forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
offset
  TV Int64
offset_index <- forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"offset_index" (TExp Int64
0 :: Imp.TExp Int64)
  let offset_index' :: TExp Int64
offset_index' = forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
offset_index

  -- Parameters used to find the chunk sizes
  -- Perhaps get this information from ``scheduling information``
  -- instead of computing it manually here.
  let iter_pr_subtask :: TExp Int64
iter_pr_subtask = forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
ns_64 forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64
nsubtasks'
      remainder :: TExp Int64
remainder = forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
ns_64 forall e. IntegralExp e => e -> e -> e
`rem` TExp Int64
nsubtasks'

  [[VName]]
carries <- String -> TV Int32 -> [SegBinOp MCMem] -> MulticoreGen [[VName]]
carryArrays String
"scan_stage_2_carry" TV Int32
nsubtasks [SegBinOp MCMem]
scan_ops
  forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"carry-in for first chunk is neutral" forall a b. (a -> b) -> a -> b
$
    forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp MCMem]
scan_ops [[VName]]
carries) forall a b. (a -> b) -> a -> b
$ \(SegBinOp MCMem
scan_op, [VName]
carry) ->
      forall rep r op.
ShapeBase SubExp
-> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (forall rep. SegBinOp rep -> ShapeBase SubExp
segBinOpShape SegBinOp MCMem
scan_op) forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is ->
        forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
carry forall a b. (a -> b) -> a -> b
$ forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp MCMem
scan_op) forall a b. (a -> b) -> a -> b
$ \(VName
carry', SubExp
ne) ->
          forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
carry' (TExp Int64
0 forall a. a -> [a] -> [a]
: [TExp Int64]
vec_is) SubExp
ne []

  -- Perform sequential scan over the last element of each chunk
  forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"scan carries" forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" (TExp Int64
nsubtasks' forall a. Num a => a -> a -> a
- TExp Int64
1) forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
    TV Int64
offset forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TExp Int64
iter_pr_subtask
    forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
i forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
remainder) (TV Int64
offset forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TExp Int64
offset' forall a. Num a => a -> a -> a
+ TExp Int64
1)
    TV Int64
offset_index forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TExp Int64
offset_index' forall a. Num a => a -> a -> a
+ TExp Int64
offset'
    forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
is forall a b. (a -> b) -> a -> b
$ forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
ns_64 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
offset_index'

    forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElem LParamMem]]
per_scan_pes [SegBinOp MCMem]
scan_ops [[VName]]
carries) forall a b. (a -> b) -> a -> b
$ \([PatElem LParamMem]
pes, SegBinOp MCMem
scan_op, [VName]
carry) ->
      forall rep r op.
ShapeBase SubExp
-> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (forall rep. SegBinOp rep -> ShapeBase SubExp
segBinOpShape SegBinOp MCMem
scan_op) forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is -> do
        forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Read carry" forall a b. (a -> b) -> a -> b
$
          forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp MCMem -> [LParam MCMem]
xParams SegBinOp MCMem
scan_op) [VName]
carry) forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, VName
carry') ->
            forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var VName
carry') (TExp Int64
i forall a. a -> [a] -> [a]
: [TExp Int64]
vec_is)

        forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Read next values" forall a b. (a -> b) -> a -> b
$
          forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp MCMem -> [LParam MCMem]
yParams SegBinOp MCMem
scan_op) [PatElem LParamMem]
pes) forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, PatElem LParamMem
pe) ->
            forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe) ((TExp Int64
offset_index' forall a. Num a => a -> a -> a
- TExp Int64
1) forall a. a -> [a] -> [a]
: [TExp Int64]
vec_is)

        forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall rep. Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ SegBinOp MCMem -> Body MCMem
lamBody SegBinOp MCMem
scan_op) forall a b. (a -> b) -> a -> b
$
          forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
carry forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall rep. Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ SegBinOp MCMem -> Body MCMem
lamBody SegBinOp MCMem
scan_op) forall a b. (a -> b) -> a -> b
$ \(VName
carry', SubExp
se) -> do
            forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
carry' ((TExp Int64
i forall a. Num a => a -> a -> a
+ TExp Int64
1) forall a. a -> [a] -> [a]
: [TExp Int64]
vec_is) SubExp
se []

  -- Return the array of carries for each chunk.
  forall (f :: * -> *) a. Applicative f => a -> f a
pure [[VName]]
carries

-- Third scan stage for operators handled by the scalar (vectorised
-- chunk loop) strategy.
--
-- Emits a @scan_stage_3@ parallel loop.  Inside the loop body the task
-- id is fetched into the flat index of the segment space.  For every
-- scan operator, the carry stored at that task id is loaded into the
-- operator's x-parameters once, before the chunk loop.  For every
-- element of the chunk, the partial result already stored in the
-- pattern is loaded into the y-parameters, the operator body is run,
-- and the result is written back to the same pattern position.
--
-- Arguments: the result pattern, the segment space, the scan
-- operators, and for each operator the arrays holding its carries.
scanStage3Scalar ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  [[VName]] ->
  MulticoreGen ()
scanStage3Scalar pat space scan_ops per_scan_carries = do
  let per_scan_pes = segBinOpChunks scan_ops $ patElems pat
      (is, ns) = unzip $ unSegSpace space
      ns' = map pe64 ns
      task_id = segFlat space
      -- The space indices as expressions, used to address the pattern.
      gtids = map Imp.le64 is

  body <- collect $ do
    dPrim_ task_id int64
    sOp $ Imp.GetTaskId task_id

    inISPC $ do
      genBinOpParams scan_ops
      -- The carry is indexed by task id only, so it is read once per
      -- task rather than once per element.
      sComment "load carry-in" $
        forM_ (zip per_scan_carries scan_ops) $ \(op_carries, scan_op) ->
          forM_ (zip (xParams scan_op) op_carries) $ \(p, carries) ->
            copyDWIMFix (paramName p) [] (Var carries) [Imp.le64 task_id]
      generateChunkLoop "SegScan" Vectorized $ \i -> do
        zipWithM_ dPrimV_ is $ unflattenIndex ns' i
        sComment "load partial result" $
          forM_ (zip per_scan_pes scan_ops) $ \(scan_pes, scan_op) ->
            forM_ (zip (yParams scan_op) scan_pes) $ \(p, pe) ->
              copyDWIMFix (paramName p) [] (Var (patElemName pe)) gtids
        sComment "combine carry with partial result" $
          forM_ (zip per_scan_pes scan_ops) $ \(scan_pes, scan_op) -> do
            let op_body = lamBody scan_op
                op_res = map resSubExp $ bodyResult op_body
            compileStms mempty (bodyStms op_body) $
              forM_ (zip scan_pes op_res) $ \(pe, se) ->
                copyDWIMFix (patElemName pe) gtids se []

  free_params <- freeParams body
  emit $ Imp.Op $ Imp.ParLoop "scan_stage_3" body free_params

-- Third scan stage for operators with a non-trivial operator shape,
-- handled with a scalar chunk loop and an inner loop nest.
--
-- Emits a @scan_stage_3@ parallel loop.  The task id is fetched into
-- the flat index of the segment space.  The operator parameters are
-- declared inside the chunk loop body.  For every element of the chunk
-- and every scan operator, a loop nest over the operator shape loads
-- the carry (indexed by task id and the nest indices) into the
-- x-parameters, loads the partial result from the pattern into the
-- y-parameters, runs the operator body, and writes the result back to
-- the same pattern position.
--
-- Arguments: the result pattern, the segment space, the scan
-- operators, and for each operator the arrays holding its carries.
scanStage3Nested ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  [[VName]] ->
  MulticoreGen ()
scanStage3Nested pat space scan_ops per_scan_carries = do
  let per_scan_pes = segBinOpChunks scan_ops $ patElems pat
      (is, ns) = unzip $ unSegSpace space
      ns' = map pe64 ns
      task_id = segFlat space
      -- The space indices as expressions, used to address the pattern.
      gtids = map Imp.le64 is
  body <- collect $ do
    dPrim_ task_id int64
    sOp $ Imp.GetTaskId task_id

    generateChunkLoop "SegScan" Scalar $ \i -> do
      genBinOpParams scan_ops
      zipWithM_ dPrimV_ is $ unflattenIndex ns' i
      forM_ (zip3 per_scan_pes per_scan_carries scan_ops) $ \(scan_pes, op_carries, scan_op) -> do
        let op_body = lamBody scan_op
            op_res = map resSubExp $ bodyResult op_body
        sLoopNest (segBinOpShape scan_op) $ \vec_is -> do
          sComment "load carry-in" $
            forM_ (zip (xParams scan_op) op_carries) $ \(p, carries) ->
              copyDWIMFix (paramName p) [] (Var carries) (Imp.le64 task_id : vec_is)

          sComment "load partial result" $
            forM_ (zip (yParams scan_op) scan_pes) $ \(p, pe) ->
              copyDWIMFix (paramName p) [] (Var (patElemName pe)) (gtids ++ vec_is)
          sComment "combine carry with partial result" $
            compileStms mempty (bodyStms op_body) $
              forM_ (zip scan_pes op_res) $ \(pe, se) ->
                copyDWIMFix (patElemName pe) (gtids ++ vec_is) se []

  free_params <- freeParams body
  emit $ Imp.Op $ Imp.ParLoop "scan_stage_3" body free_params

-- Third scan stage, general fallback strategy.
--
-- Emits a @scan_stage_3@ parallel loop.  The task id is fetched into
-- the flat index of the segment space and the operator parameters are
-- declared once, before the scalar chunk loop.  For every element of
-- the chunk and every scan operator, a loop nest over the operator
-- shape loads the carry (indexed by task id and the nest indices) into
-- the x-parameters, loads the partial result from the pattern into the
-- y-parameters, runs the operator body, and writes the result back to
-- the same pattern position.
--
-- Arguments: the result pattern, the segment space, the scan
-- operators, and for each operator the arrays holding its carries.
scanStage3Fallback ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  [[VName]] ->
  MulticoreGen ()
scanStage3Fallback pat space scan_ops per_scan_carries = do
  let per_scan_pes = segBinOpChunks scan_ops $ patElems pat
      (is, ns) = unzip $ unSegSpace space
      ns' = map pe64 ns
      task_id = segFlat space
      -- The space indices as expressions, used to address the pattern.
      gtids = map Imp.le64 is
  body <- collect $ do
    dPrim_ task_id int64
    sOp $ Imp.GetTaskId task_id

    genBinOpParams scan_ops

    generateChunkLoop "SegScan" Scalar $ \i -> do
      zipWithM_ dPrimV_ is $ unflattenIndex ns' i
      forM_ (zip3 per_scan_pes per_scan_carries scan_ops) $ \(scan_pes, op_carries, scan_op) -> do
        let op_body = lamBody scan_op
            op_res = map resSubExp $ bodyResult op_body
        sLoopNest (segBinOpShape scan_op) $ \vec_is -> do
          sComment "load carry-in" $
            forM_ (zip (xParams scan_op) op_carries) $ \(p, carries) ->
              copyDWIMFix (paramName p) [] (Var carries) (Imp.le64 task_id : vec_is)

          sComment "load partial result" $
            forM_ (zip (yParams scan_op) scan_pes) $ \(p, pe) ->
              copyDWIMFix (paramName p) [] (Var (patElemName pe)) (gtids ++ vec_is)
          sComment "combine carry with partial result" $
            compileStms mempty (bodyStms op_body) $
              forM_ (zip scan_pes op_res) $ \(pe, se) ->
                copyDWIMFix (patElemName pe) (gtids ++ vec_is) se []
  free_params <- freeParams body
  emit $ Imp.Op $ Imp.ParLoop "scan_stage_3" body free_params

-- Segmented scan.  'compileSegScan' dispatches here whenever the
-- segment space does not consist of exactly one dimension.
--
-- This implementation only parallelises over the segments; each
-- segment is scanned sequentially (see 'compileSegScanBody').
--
-- A debug-print statement is emitted into the enclosing code, and the
-- returned code consists of a single @seg_scan@ parallel loop whose
-- body is produced by 'compileSegScanBody'.
segmentedScan ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen Imp.MCCode
segmentedScan pat space scan_ops kbody = do
  emit $ Imp.DebugPrint "segmented segScan" Nothing
  collect $ do
    body <- compileSegScanBody pat space scan_ops kbody
    free_params <- freeParams body
    emit $ Imp.Op $ Imp.ParLoop "seg_scan" body free_params

-- Generate the per-task body of a segmented scan.
--
-- The task id is fetched into the flat index of the segment space.  A
-- scalar chunk loop then iterates over segments.  For each segment and
-- each scan operator, the x-parameters (the running carry) are reset
-- to the neutral element and the innermost dimension is scanned
-- sequentially: the kernel body is run, its to-scan results are copied
-- into the y-parameters, its mapped results are written to the
-- pattern, the operator body is run, and the combined value is written
-- both to the pattern and back into the x-parameters as the carry for
-- the next iteration.
--
-- Uses 'init' and 'last' on the space dimensions, so the segment space
-- is assumed to be non-empty.
--
-- NOTE(review): every operator takes the /leading/ results of the
-- kernel body as its to-scan values, and the kernel body is compiled
-- once per operator.  This looks correct only when there is a single
-- scan operator -- confirm before relying on it with several.
compileSegScanBody ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen Imp.MCCode
compileSegScanBody pat space scan_ops kbody = collect $ do
  let (is, ns) = unzip $ unSegSpace space
      ns_64 = map pe64 ns

  dPrim_ (segFlat space) int64
  sOp $ Imp.GetTaskId (segFlat space)

  let per_scan_pes = segBinOpChunks scan_ops $ patElems pat
  generateChunkLoop "SegScan" Scalar $ \segment_i -> do
    forM_ (zip scan_ops per_scan_pes) $ \(scan_op, scan_pes) -> do
      let lam = segBinOpLambda scan_op
          neutral = segBinOpNeutral scan_op
          num_scan_res = length neutral
          (scan_x_params, scan_y_params) = splitAt num_scan_res $ lambdaParams lam

      dScope Nothing $ scopeOfLParams $ lambdaParams lam

      -- Every segment starts from the neutral element.
      forM_ (zip scan_x_params neutral) $ \(p, ne) ->
        copyDWIMFix (paramName p) [] ne []

      let inner_bound = last ns_64
      -- Perform a sequential scan over the segment ``segment_i``
      sFor "i" inner_bound $ \i -> do
        -- The outer indices come from the segment number, the
        -- innermost index from the sequential loop counter.
        zipWithM_ dPrimV_ (init is) $ unflattenIndex (init ns_64) segment_i
        dPrimV_ (last is) i
        compileStms mempty (kernelBodyStms kbody) $ do
          let (scan_res, map_res) = splitAt num_scan_res $ kernelBodyResult kbody
          sComment "write to-scan values to parameters" $
            forM_ (zip scan_y_params scan_res) $ \(p, res) ->
              copyDWIMFix (paramName p) [] (kernelResultSubExp res) []

          sComment "write mapped values results to memory" $
            forM_ (zip (drop num_scan_res $ patElems pat) map_res) $ \(pe, res) ->
              copyDWIMFix (patElemName pe) (map Imp.le64 is) (kernelResultSubExp res) []

          sComment "combine with carry and write to memory" $
            compileStms mempty (bodyStms $ lambdaBody lam) $
              forM_ (zip3 scan_x_params scan_pes $ map resSubExp $ bodyResult $ lambdaBody lam) $ \(p, pe, se) -> do
                copyDWIMFix (patElemName pe) (map Imp.le64 is) se []
                -- The combined value is the carry for the next element.
                copyDWIMFix (paramName p) [] se []