{-# LANGUAGE TypeFamilies #-}

-- | We generate code for non-segmented/single-segment SegRed using
-- the basic approach outlined in the paper "Design and GPGPU
-- Performance of Futhark’s Redomap Construct" (ARRAY '16).  The main
-- deviations are:
--
-- * While we still use two-phase reduction, we use only a single
--   kernel, with the final workgroup to write a result (tracked via
--   an atomic counter) performing the final reduction as well.
--
-- * Instead of depending on storage layout transformations to handle
--   non-commutative reductions efficiently, we slide a
--   @group_size@-sized window over the input, and perform a parallel
--   reduction for each window.  This sacrifices the notion of
--   efficient sequentialisation, but is sometimes faster and
--   definitely simpler and more predictable (and uses less auxiliary
--   storage).
--
-- For segmented reductions we use the approach from "Strategies for
-- Regular Segmented Reductions on GPU" (FHPC '17).  This involves
-- having two different strategies, and dynamically deciding which one
-- to use based on the number of segments and segment size. We use the
-- (static) @group_size@ to decide which of the following two
-- strategies to choose:
--
-- * Large: uses one or more groups to process a single segment. If
--   multiple groups are used per segment, the intermediate reduction
--   results must be recursively reduced, until there is only a single
--   value per segment.
--
--   Each thread /can/ read multiple elements, which will greatly
--   increase performance; however, if the reduction is
--   non-commutative we will have to use a less efficient traversal
--   (with interim group-wide reductions) to enable coalesced memory
--   accesses, just as in the non-segmented case.
--
-- * Small: lets each group process /multiple/ segments. We only use
--   this approach when a segment is less than half the group size,
--   so that at least two segments fit within a single group. In those
--   cases, the large strategy would allocate a /whole/ group per
--   segment, but at most 50% of the threads in the group would have
--   any element to read, which becomes highly inefficient.
module Futhark.CodeGen.ImpGen.GPU.SegRed
  ( compileSegRed,
    compileSegRed',
    DoSegBody,
  )
where

import Control.Monad.Except
import Data.List (genericLength, zip7)
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.Transform.Rename
import Futhark.Util (chunks)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)

-- | The maximum number of operators we support in a single SegRed.
-- This limit arises out of the static allocation of counters: the
-- non-segmented reduction allocates exactly this many counter slots.
maxNumOps :: Int32
maxNumOps = 10

-- | Code generation for the body of a SegRed.
--
-- The action is handed a continuation that it must invoke to save
-- the results of the body.  Each result is given as a t'SubExp'
-- together with the list of indexes to use when reading from that
-- t'SubExp'.
type DoSegBody =
  ([(SubExp, [Imp.TExp Int64])] -> InKernelGen ()) ->
  InKernelGen ()

-- | Compile 'SegRed' instance to host-level code with calls to
-- various kernels.
--
-- The statements of the kernel body are compiled inside the
-- reduction kernel.  The first 'segBinOpResults' results of the body
-- are passed on to the reduction operators; the remaining results are
-- map-out results, which are written to the trailing pattern elements
-- with 'compileThreadResult'.
compileSegRed ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegRed pat lvl space reds body =
  compileSegRed' pat lvl space reds $ \red_cont ->
    compileStms mempty (kernelBodyStms body) $ do
      let (red_res, map_res) =
            splitAt (segBinOpResults reds) $ kernelBodyResult body

      sComment "save map-out results" $ do
        let map_arrs = drop (segBinOpResults reds) $ patElems pat
        zipWithM_ (compileThreadResult space) map_arrs map_res

      -- Reduction results are read directly, without any indexing.
      red_cont $ zip (map kernelResultSubExp red_res) $ repeat []

-- | Like 'compileSegRed', but where the body is a monadic action.
--
-- The code generation strategy is chosen as follows:
--
-- * More than 'maxNumOps' operators is reported as a compiler
--   limitation.
--
-- * A two-dimensional segment space whose outer dimension is the
--   constant 1 is compiled with 'nonsegmentedReduction'.
--
-- * Otherwise the choice is made at run time: 'smallSegmentsReduction'
--   is used when twice the (innermost) segment size is strictly less
--   than the group size, and 'largeSegmentsReduction' otherwise.
compileSegRed' ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  DoSegBody ->
  CallKernelGen ()
compileSegRed' pat lvl space reds body
  | genericLength reds > maxNumOps =
      compilerLimitationS $
        "compileSegRed': at most "
          ++ show maxNumOps
          ++ " reduction operators are supported."
  | [(_, Constant (IntValue (Int64Value 1))), _] <- unSegSpace space =
      nonsegmentedReduction pat num_groups group_size space reds body
  | otherwise = do
      let group_size' = pe64 $ unCount group_size
          segment_size = pe64 $ last $ segSpaceDims space
          use_small_segments = segment_size * 2 .<. group_size'
      sIf
        use_small_segments
        (smallSegmentsReduction pat num_groups group_size space reds body)
        (largeSegmentsReduction pat num_groups group_size space reds body)
  where
    num_groups = segNumGroups lvl
    group_size = segGroupSize lvl

-- | Prepare intermediate arrays for the reduction, one per accumulator
-- parameter of the reduction operator.
--
-- * Prim-typed accumulators are placed in local memory.  Their
--   allocation (an array of @group_size@ elements) happens here,
--   inside the kernel.
--
-- * Array-typed accumulators go in global memory.  Their memory block
--   has already been allocated, so here we only create an array in
--   that existing memory, with an outer dimension of @num_threads@
--   and a row-major ('IxFun.iota') index function.
--
-- This policy is baked into how the allocations are done in
-- ExplicitAllocations.
intermediateArrays ::
  Count GroupSize SubExp ->
  SubExp ->
  SegBinOp GPUMem ->
  InKernelGen [VName]
intermediateArrays (Count group_size) num_threads (SegBinOp _ red_op nes _) = do
  -- The first (length nes) parameters are the accumulators.
  let (red_acc_params, _) = splitAt (length nes) $ lambdaParams red_op
  forM red_acc_params $ \p ->
    case paramDec p of
      MemArray pt shape _ (ArrayIn mem _) -> do
        let shape' = Shape [num_threads] <> shape
        sArray "red_arr" pt shape' mem $
          IxFun.iota $ map pe64 $ shapeDims shape'
      _ -> do
        let pt = elemType $ paramType p
            shape = Shape [group_size]
        sAllocArray "red_arr" pt shape $ Space "local"

-- | Arrays for storing group results, one per return value of each
-- reduction operator, allocated in device memory.
--
-- The group-result arrays have an extra outer dimension because they
-- are also used for keeping vectorised accumulators for first-stage
-- reduction.  That dimension has size @group_size@ when the operator
-- is vectorised (non-empty operator shape) or returns an array, and
-- size 1 otherwise.  When actually storing group results, the first
-- index is set to 0.
groupResultArrays ::
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  [SegBinOp GPUMem] ->
  CallKernelGen [[VName]]
groupResultArrays (Count virt_num_groups) (Count group_size) reds =
  forM reds $ \(SegBinOp _ lam _ shape) ->
    forM (lambdaReturnType lam) $ \t -> do
      let pt = elemType t
          extra_dim
            | primType t, shapeRank shape == 0 = intConst Int64 1
            | otherwise = group_size
          full_shape =
            Shape [extra_dim, virt_num_groups] <> shape <> arrayShape t
          -- Move the groupsize dimension last to ensure coalesced
          -- memory access.
          perm = [1 .. shapeRank full_shape - 1] ++ [0]
      sAllocArrayPerm "segred_tmp" pt full_shape (Space "device") perm

-- | Generate a single kernel (@segred_nonseg@) for a reduction whose
-- segment space consists of a single segment.
--
-- Each thread first reduces a chunk of @elems_per_thread@ elements
-- ('reductionStageOne').  The per-group results are then combined by
-- 'reductionStageTwo'.  That stage is coordinated through a
-- statically allocated, zero-initialised @int32@ counter array in
-- device memory with 'maxNumOps' slots; operator number @i@ uses
-- slot @i@.
nonsegmentedReduction ::
  Pat LetDecMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  DoSegBody ->
  CallKernelGen ()
nonsegmentedReduction segred_pat num_groups group_size space reds body = do
  let (gtids, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims
      num_groups' = fmap pe64 num_groups
      group_size' = fmap pe64 group_size
      global_tid = Imp.le64 $ segFlat space
      w = last dims'

  counter <-
    sStaticArray "counter" (Space "device") int32 $
      Imp.ArrayZeros (fromIntegral maxNumOps)

  reds_group_res_arrs <- groupResultArrays num_groups group_size reds

  num_threads <-
    dPrimV "num_threads" $ unCount num_groups' * unCount group_size'

  emit $ Imp.DebugPrint "\n# SegRed" Nothing

  sKernelThread "segred_nonseg" (segFlat space) (defKernelAttrs num_groups group_size) $ do
    constants <- kernelConstants <$> askEnv
    sync_arr <- sAllocArray "sync_arr" Bool (Shape [intConst Int32 1]) $ Space "local"
    reds_arrs <- mapM (intermediateArrays group_size (tvSize num_threads)) reds

    -- Since this is the nonsegmented case, all outer segment IDs must
    -- necessarily be 0.
    forM_ gtids $ \v -> dPrimV_ v (0 :: Imp.TExp Int64)

    let num_elements = Imp.elements w
        elems_per_thread =
          num_elements
            `divUp` Imp.elements (sExt64 (kernelNumThreads constants))

    slugs <-
      mapM (segBinOpSlug (kernelLocalThreadId constants) (kernelGroupId constants)) $
        zip3 reds reds_arrs reds_group_res_arrs
    reds_op_renamed <-
      reductionStageOne
        constants
        (zip gtids dims')
        num_elements
        global_tid
        elems_per_thread
        (tvExp num_threads)
        slugs
        body

    let segred_pes =
          chunks (map (length . segBinOpNeutral) reds) $
            patElems segred_pat
    forM_ (zip7 reds reds_arrs reds_group_res_arrs segred_pes slugs reds_op_renamed [0 ..]) $
      \(SegBinOp _ red_op nes _, red_arrs, group_res_arrs, pes, slug, red_op_renamed, i) -> do
        let (red_x_params, red_y_params) = splitAt (length nes) $ lambdaParams red_op
        reductionStageTwo
          constants
          pes
          (kernelGroupId constants)
          0
          [0]
          0
          (sExt64 $ kernelNumGroups constants)
          slug
          red_x_params
          red_y_params
          red_op_renamed
          nes
          1
          counter
          (fromInteger i)
          sync_arr
          group_res_arrs
          red_arrs

  emit $ Imp.DebugPrint "" Nothing

smallSegmentsReduction ::
  Pat LetDecMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  DoSegBody ->
  CallKernelGen ()
smallSegmentsReduction :: Pat LParamMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp GPUMem]
-> DoSegBody
-> CallKernelGen ()
smallSegmentsReduction (Pat [PatElem LParamMem]
segred_pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegBinOp GPUMem]
reds DoSegBody
body = do
  let ([VName]
gtids, [SubExp]
dims) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      dims' :: [TPrimExp Int64 VName]
dims' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
      segment_size :: TPrimExp Int64 VName
segment_size = forall a. [a] -> a
last [TPrimExp Int64 VName]
dims'

  -- Careful to avoid division by zero now.
  TPrimExp Int64 VName
segment_size_nonzero <-
    forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"segment_size_nonzero" forall a b. (a -> b) -> a -> b
$ forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMax64 TPrimExp Int64 VName
1 TPrimExp Int64 VName
segment_size

  let num_groups' :: Count NumGroups (TPrimExp Int64 VName)
num_groups' = forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64 Count NumGroups SubExp
num_groups
      group_size' :: Count GroupSize (TPrimExp Int64 VName)
group_size' = forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64 Count GroupSize SubExp
group_size
  TV Int64
num_threads <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"num_threads" forall a b. (a -> b) -> a -> b
$ forall {k} (u :: k) e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
num_groups' forall a. Num a => a -> a -> a
* forall {k} (u :: k) e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 VName)
group_size'
  let num_segments :: TPrimExp Int64 VName
num_segments = forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product forall a b. (a -> b) -> a -> b
$ forall a. [a] -> [a]
init [TPrimExp Int64 VName]
dims'
      segments_per_group :: TPrimExp Int64 VName
segments_per_group = forall {k} (u :: k) e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 VName)
group_size' forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp Int64 VName
segment_size_nonzero
      required_groups :: TPrimExp Int32 VName
required_groups = forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
num_segments forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int64 VName
segments_per_group

  forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"\n# SegRed-small" forall a. Maybe a
Nothing
  forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"num_segments" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
num_segments
  forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"segment_size" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
segment_size
  forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"segments_per_group" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
segments_per_group
  forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"required_groups" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int32 VName
required_groups

  String
-> VName -> KernelAttrs -> InKernelGen () -> CallKernelGen ()
sKernelThread String
"segred_small" (SegSpace -> VName
segFlat SegSpace
space) (Count NumGroups SubExp -> Count GroupSize SubExp -> KernelAttrs
defKernelAttrs Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size) forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. ImpM rep r op r
askEnv
    [[VName]]
reds_arrs <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Count GroupSize SubExp
-> SubExp -> SegBinOp GPUMem -> InKernelGen [VName]
intermediateArrays Count GroupSize SubExp
group_size (VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> VName
tvVar TV Int64
num_threads)) [SegBinOp GPUMem]
reds

    -- We probably do not have enough actual workgroups to cover the
    -- entire iteration space.  Some groups thus have to perform double
    -- duty; we put an outer loop to accomplish this.
    SegVirt
-> TPrimExp Int32 VName
-> (TPrimExp Int32 VName -> InKernelGen ())
-> InKernelGen ()
virtualiseGroups SegVirt
SegVirt TPrimExp Int32 VName
required_groups forall a b. (a -> b) -> a -> b
$ \TPrimExp Int32 VName
group_id' -> do
      -- Compute the 'n' input indices.  The outer 'n-1' correspond to
      -- the segment ID, and are computed from the group id.  The inner
      -- is computed from the local thread id, and may be out-of-bounds.
      let ltid :: TPrimExp Int64 VName
ltid = forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ KernelConstants -> TPrimExp Int32 VName
kernelLocalThreadId KernelConstants
constants
          segment_index :: TPrimExp Int64 VName
segment_index =
            (TPrimExp Int64 VName
ltid forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp Int64 VName
segment_size_nonzero)
              forall a. Num a => a -> a -> a
+ (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
group_id' forall a. Num a => a -> a -> a
* forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 VName
segments_per_group)
          index_within_segment :: TPrimExp Int64 VName
index_within_segment = TPrimExp Int64 VName
ltid forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName
segment_size

      forall {k} (rep :: k) r op.
[(VName, TPrimExp Int64 VName)]
-> TPrimExp Int64 VName -> ImpM rep r op ()
dIndexSpace (forall a b. [a] -> [b] -> [(a, b)]
zip (forall a. [a] -> [a]
init [VName]
gtids) (forall a. [a] -> [a]
init [TPrimExp Int64 VName]
dims')) TPrimExp Int64 VName
segment_index
      forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ (forall a. [a] -> a
last [VName]
gtids) TPrimExp Int64 VName
index_within_segment

      let out_of_bounds :: InKernelGen ()
out_of_bounds =
            forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp GPUMem]
reds [[VName]]
reds_arrs) forall a b. (a -> b) -> a -> b
$ \(SegBinOp Commutativity
_ Lambda GPUMem
_ [SubExp]
nes Shape
_, [VName]
red_arrs) ->
              forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
red_arrs [SubExp]
nes) forall a b. (a -> b) -> a -> b
$ \(VName
arr, SubExp
ne) ->
                forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
arr [TPrimExp Int64 VName
ltid] SubExp
ne []

          in_bounds :: InKernelGen ()
in_bounds =
            DoSegBody
body forall a b. (a -> b) -> a -> b
$ \[(SubExp, [TPrimExp Int64 VName])]
red_res ->
              forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"save results to be reduced" forall a b. (a -> b) -> a -> b
$ do
                let red_dests :: [(VName, [TPrimExp Int64 VName])]
red_dests = forall a b. [a] -> [b] -> [(a, b)]
zip (forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
reds_arrs) forall a b. (a -> b) -> a -> b
$ forall a. a -> [a]
repeat [TPrimExp Int64 VName
ltid]
                forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [(VName, [TPrimExp Int64 VName])]
red_dests [(SubExp, [TPrimExp Int64 VName])]
red_res) forall a b. (a -> b) -> a -> b
$ \((VName
d, [TPrimExp Int64 VName]
d_is), (SubExp
res, [TPrimExp Int64 VName]
res_is)) ->
                  forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
d [TPrimExp Int64 VName]
d_is SubExp
res [TPrimExp Int64 VName]
res_is

      forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"apply map function if in bounds" forall a b. (a -> b) -> a -> b
$
        forall {k} (rep :: k) r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
          ( TPrimExp Int64 VName
segment_size
              forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TPrimExp Int64 VName
0
              forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. [(VName, SubExp)] -> TPrimExp Bool VName
isActive (forall a. [a] -> [a]
init forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [SubExp]
dims)
              forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Int64 VName
ltid
                forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 VName
segment_size
                  forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
segments_per_group
          )
          InKernelGen ()
in_bounds
          InKernelGen ()
out_of_bounds

      forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal -- Also implicitly barrier.
      let crossesSegment :: TPrimExp Int32 VName -> TPrimExp Int32 VName -> TPrimExp Bool VName
crossesSegment TPrimExp Int32 VName
from TPrimExp Int32 VName
to =
            (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
to forall a. Num a => a -> a -> a
- forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
from) forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
to forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName
segment_size)
      forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Int64 VName
segment_size forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TPrimExp Int64 VName
0) forall a b. (a -> b) -> a -> b
$
        forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"perform segmented scan to imitate reduction" forall a b. (a -> b) -> a -> b
$
          forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp GPUMem]
reds [[VName]]
reds_arrs) forall a b. (a -> b) -> a -> b
$ \(SegBinOp Commutativity
_ Lambda GPUMem
red_op [SubExp]
_ Shape
_, [VName]
red_arrs) ->
            Maybe
  (TPrimExp Int32 VName
   -> TPrimExp Int32 VName -> TPrimExp Bool VName)
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> Lambda GPUMem
-> [VName]
-> InKernelGen ()
groupScan
              (forall a. a -> Maybe a
Just TPrimExp Int32 VName -> TPrimExp Int32 VName -> TPrimExp Bool VName
crossesSegment)
              (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
num_threads)
              (TPrimExp Int64 VName
segment_size forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
segments_per_group)
              Lambda GPUMem
red_op
              [VName]
red_arrs

      forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal

      forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"save final values of segments"
        forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen
          ( forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
group_id'
              forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
segments_per_group
              forall a. Num a => a -> a -> a
+ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 VName
ltid
              forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 VName
num_segments
              forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Int64 VName
ltid
                forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 VName
segments_per_group
          )
        forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
segred_pes (forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
reds_arrs))
        forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, VName
arr) -> do
          -- Figure out which segment result this thread should write...
          let flat_segment_index :: TPrimExp Int64 VName
flat_segment_index =
                forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
group_id' forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
segments_per_group forall a. Num a => a -> a -> a
+ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 VName
ltid
              gtids' :: [TPrimExp Int64 VName]
gtids' =
                forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex (forall a. [a] -> [a]
init [TPrimExp Int64 VName]
dims') TPrimExp Int64 VName
flat_segment_index
          forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
            (forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
            [TPrimExp Int64 VName]
gtids'
            (VName -> SubExp
Var VName
arr)
            [(TPrimExp Int64 VName
ltid forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1) forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
segment_size_nonzero forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1]

      -- Finally another barrier, because we will be writing to the
      -- local memory array first thing in the next iteration.
      forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal

  forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"" forall a. Maybe a
Nothing

-- | Generate code for a segmented reduction whose segments are handled
-- by one or more whole workgroups each (the "large" strategy).
--
-- @groups_per_segment * num_segments@ virtual groups are executed by
-- the physical groups of a single kernel.  Every group first runs stage
-- one ('reductionStageOne') over its share of a segment.  If only one
-- group is assigned per segment, the first thread of each group writes
-- the result straight to the pattern.  Otherwise 'reductionStageTwo'
-- combines the per-group results, using the static @counter@ array.
largeSegmentsReduction ::
  Pat LetDecMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  DoSegBody ->
  CallKernelGen ()
largeSegmentsReduction segred_pat num_groups group_size space reds body = do
  let (gtids, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims
      -- All but the innermost dimension enumerate segments; the
      -- innermost dimension is the segment itself.
      num_segments = product $ init dims'
      segment_size = last dims'
      num_groups' = fmap pe64 num_groups
      group_size' = fmap pe64 group_size

  (groups_per_segment, elems_per_thread) <-
    groupsPerSegmentAndElementsPerThread
      segment_size
      num_segments
      num_groups'
      group_size'
  virt_num_groups <-
    dPrimV "virt_num_groups" $
      groups_per_segment * num_segments

  num_threads <-
    dPrimV "num_threads" $
      unCount num_groups' * unCount group_size'

  threads_per_segment <-
    dPrimV "threads_per_segment" $
      groups_per_segment * unCount group_size'

  -- Debug output, emitted in a fixed order.
  emit $ Imp.DebugPrint "\n# SegRed-large" Nothing
  forM_
    [ ("num_segments", num_segments),
      ("segment_size", segment_size),
      ("virt_num_groups", tvExp virt_num_groups),
      ("num_groups", Imp.unCount num_groups'),
      ("group_size", Imp.unCount group_size'),
      ("elems_per_thread", Imp.unCount elems_per_thread),
      ("groups_per_segment", groups_per_segment)
    ]
    $ \(desc, e) -> emit $ Imp.DebugPrint desc $ Just $ untyped e

  reds_group_res_arrs <-
    groupResultArrays (Count (tvSize virt_num_groups)) group_size reds

  -- In principle we should have a counter for every segment.  Since
  -- the number of segments is a dynamic quantity, we would have to
  -- allocate and zero out an array here, which is expensive.
  -- However, we exploit the fact that the number of segments being
  -- reduced at any point in time is limited by the number of
  -- workgroups. If we bound the number of workgroups, we can get away
  -- with using that many counters.  FIXME: Is this limit checked
  -- anywhere?  There are other places in the compiler that will fail
  -- if the group count exceeds the maximum group size, which is at
  -- most 1024 anyway.
  --
  -- NOTE(review): the 'maxNumOps' factor presumably reserves a bank of
  -- 1024 counters per operator index @i@ passed to 'reductionStageTwo';
  -- confirm against that function.
  let num_counters :: Int
      num_counters = fromIntegral maxNumOps * 1024
  counter <-
    sStaticArray "counter" (Space "device") int32 $
      Imp.ArrayZeros num_counters

  sKernelThread "segred_large" (segFlat space) (defKernelAttrs num_groups group_size) $ do
    constants <- kernelConstants <$> askEnv
    reds_arrs <- mapM (intermediateArrays group_size (tvSize num_threads)) reds
    sync_arr <- sAllocArray "sync_arr" Bool (Shape [intConst Int32 1]) $ Space "local"

    -- We probably do not have enough actual workgroups to cover the
    -- entire iteration space.  Some groups thus have to perform double
    -- duty; we put an outer loop to accomplish this.
    virtualiseGroups SegVirt (sExt32 (tvExp virt_num_groups)) $ \group_id -> do
      let segment_gtids = init gtids
          w = last dims
          local_tid = kernelLocalThreadId constants

      -- Consecutive virtual groups share a segment.
      flat_segment_id <-
        dPrimVE "flat_segment_id" $
          group_id `quot` sExt32 groups_per_segment

      -- Index of this thread among all threads working on its segment.
      global_tid <-
        dPrimVE "global_tid" $
          (sExt64 group_id * sExt64 (unCount group_size') + sExt64 local_tid)
            `rem` (sExt64 (unCount group_size') * groups_per_segment)

      let first_group_for_segment = sExt64 flat_segment_id * groups_per_segment
      dIndexSpace (zip segment_gtids (init dims')) $ sExt64 flat_segment_id
      dPrim_ (last gtids) int64
      let num_elements = Imp.elements $ pe64 w

      slugs <-
        mapM (segBinOpSlug local_tid group_id) $
          zip3 reds reds_arrs reds_group_res_arrs
      reds_op_renamed <-
        reductionStageOne
          constants
          (zip gtids dims')
          num_elements
          global_tid
          elems_per_thread
          (tvExp threads_per_segment)
          slugs
          body

      let segred_pes =
            chunks (map (length . segBinOpNeutral) reds) $
              patElems segred_pat

          -- Several groups contribute to each segment: combine their
          -- partial results in stage two.
          multiple_groups_per_segment :: InKernelGen ()
          multiple_groups_per_segment =
            forM_ (zip7 reds reds_arrs reds_group_res_arrs segred_pes slugs reds_op_renamed [0 ..]) $
              \(SegBinOp _ red_op nes _, red_arrs, group_res_arrs, pes, slug, red_op_renamed, i) -> do
                let (red_x_params, red_y_params) =
                      splitAt (length nes) $ lambdaParams red_op
                reductionStageTwo
                  constants
                  pes
                  group_id
                  flat_segment_id
                  (map Imp.le64 segment_gtids)
                  (sExt64 first_group_for_segment)
                  groups_per_segment
                  slug
                  red_x_params
                  red_y_params
                  red_op_renamed
                  nes
                  (fromIntegral num_counters)
                  counter
                  (fromInteger i)
                  sync_arr
                  group_res_arrs
                  red_arrs

          -- A single group covers each segment: its stage-one result is
          -- already final.
          one_group_per_segment :: InKernelGen ()
          one_group_per_segment =
            comment "first thread in group saves final result to memory" $
              forM_ (zip slugs segred_pes) $ \(slug, pes) ->
                sWhen (local_tid .==. 0) $
                  forM_ (zip pes (slugAccs slug)) $ \(v, (acc, acc_is)) ->
                    copyDWIMFix (patElemName v) (map Imp.le64 segment_gtids) (Var acc) acc_is

      sIf (groups_per_segment .==. 1) one_group_per_segment multiple_groups_per_segment

  emit $ Imp.DebugPrint "" Nothing

-- | Decide how many groups cooperate on each segment, and how many
-- elements each thread of those groups must process.
--
-- Careful to avoid division by zero here.  @num_segments@ is clamped to
-- at least one before dividing.  The resulting @groups_per_segment@ is
-- also clamped to at least one, so we always have at least one group
-- per segment, even if the group-count hint is zero.  That value is
-- then part of the divisor for @elements_per_thread@.  The group size
-- itself is assumed to be positive.
groupsPerSegmentAndElementsPerThread ::
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Count NumGroups (Imp.TExp Int64) ->
  Count GroupSize (Imp.TExp Int64) ->
  CallKernelGen
    ( Imp.TExp Int64,
      Imp.Count Imp.Elements (Imp.TExp Int64)
    )
groupsPerSegmentAndElementsPerThread segment_size num_segments num_groups_hint group_size = do
  groups_per_segment <-
    dPrimVE "groups_per_segment" $
      sMax64 1 $
        unCount num_groups_hint `divUp` sMax64 1 num_segments
  elements_per_thread <-
    dPrimVE "elements_per_thread" $
      segment_size `divUp` (unCount group_size * groups_per_segment)
  pure (groups_per_segment, Imp.elements elements_per_thread)

-- | A SegBinOp with auxiliary information.
data SegBinOpSlug = SegBinOpSlug
  { -- | The reduction operator itself.
    slugOp :: SegBinOp GPUMem,
    -- | The arrays used for computing the intra-group reduction
    -- (either local or global memory).
    slugArrs :: [VName],
    -- | Places to store accumulator in stage 1 reduction: a variable or
    -- array name, together with the indices at which to access it (an
    -- empty index list denotes a scalar variable).
    slugAccs :: [(VName, [Imp.TExp Int64])]
  }

-- | The body of the slug's reduction lambda.
slugBody :: SegBinOpSlug -> Body GPUMem
slugBody = lambdaBody . segBinOpLambda . slugOp

-- | All parameters of the slug's reduction lambda (accumulator
-- parameters followed by next-element parameters).
slugParams :: SegBinOpSlug -> [LParam GPUMem]
slugParams = lambdaParams . segBinOpLambda . slugOp

-- | The neutral elements of the slug's reduction operator.
slugNeutral :: SegBinOpSlug -> [SubExp]
slugNeutral = segBinOpNeutral . slugOp

-- | The shape of the slug's reduction operator ('segBinOpShape').
slugShape :: SegBinOpSlug -> Shape
slugShape = segBinOpShape . slugOp

-- | The combined commutativity of several slugs, obtained through the
-- 'Monoid' instance of 'Commutativity' (presumably non-commutative as
-- soon as any operator is).
slugsComm :: [SegBinOpSlug] -> Commutativity
slugsComm = foldMap (segBinOpComm . slugOp)

-- | Split the reduction lambda's parameters in two halves, each as long
-- as the list of neutral elements.
accParams, nextParams :: SegBinOpSlug -> [LParam GPUMem]
-- The leading parameters (the accumulator side of the operator).
accParams slug = take (length (slugNeutral slug)) $ slugParams slug
-- The trailing parameters (the incoming element side of the operator).
nextParams slug = drop (length (slugNeutral slug)) $ slugParams slug

-- | Build the 'SegBinOpSlug' for one reduction operator.
--
-- The second tuple component is stored unchanged as 'slugArrs'.  The
-- operator's lambda parameters are paired with the arrays of the third
-- component to form the accumulator locations ('slugAccs').  Because
-- 'zipWithM' stops at the shorter list, only as many parameters are
-- used as there are arrays.  For each pair:
--
-- * a parameter of primitive type, when the operator's shape has rank
--   zero, gets a freshly declared scalar variable (no indices);
--
-- * otherwise the array itself is used, indexed by
--   @[local_tid, group_id]@.
segBinOpSlug ::
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  (SegBinOp GPUMem, [VName], [VName]) ->
  InKernelGen SegBinOpSlug
segBinOpSlug local_tid group_id (op, slug_arrs, acc_arrs) =
  SegBinOpSlug op slug_arrs
    <$> zipWithM mkAcc (lambdaParams (segBinOpLambda op)) acc_arrs
  where
    mkAcc p acc_arr
      | Prim t <- paramType p,
        shapeRank (segBinOpShape op) == 0 = do
          acc <- dPrim (baseString (paramName p) <> "_acc") t
          pure (tvVar acc, [])
      | otherwise =
          pure (acc_arr, [sExt64 local_tid, sExt64 group_id])

-- | Emit code that stores in @chunk_var@ the number of elements the
-- thread with index @thread_index@ should process.
--
-- * 'Commutative': the thread's elements are strided by
--   @threads_per_segment@.  The chunk is
--   @min elements_per_thread ((num_elements - thread_index) `divUp` threads_per_segment)@.
--
-- * 'Noncommutative': the thread is assigned the contiguous range
--   starting at @thread_index * elements_per_thread@.  The chunk is:
--
--     * 0 if that range starts at or beyond @num_elements@;
--     * truncated to the remaining elements if the range crosses
--       @num_elements@;
--     * @elements_per_thread@ otherwise.
--
--   @threads_per_segment@ is not used in this case.
computeThreadChunkSize ::
  Commutativity ->
  -- | Number of threads cooperating on a segment (commutative case only).
  Imp.TExp Int64 ->
  -- | Index of the thread whose chunk is computed.
  Imp.TExp Int64 ->
  -- | Maximum number of elements handled by one thread.
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  -- | Total number of elements to process.
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  -- | Variable that receives the chunk size.
  TV Int64 ->
  ImpM rep r op ()
computeThreadChunkSize
  Commutative
  threads_per_segment
  thread_index
  elements_per_thread
  num_elements
  chunk_var =
    chunk_var
      <-- sMin64
        (Imp.unCount elements_per_thread)
        ((Imp.unCount num_elements - thread_index) `divUp` threads_per_segment)
computeThreadChunkSize
  Noncommutative
  _
  thread_index
  elements_per_thread
  num_elements
  chunk_var = do
    starting_point <-
      dPrimV "starting_point" $
        thread_index * Imp.unCount elements_per_thread
    remaining_elements <-
      dPrimV "remaining_elements" $
        Imp.unCount num_elements - tvExp starting_point

    -- NOTE(review): because remaining_elements = num_elements -
    -- starting_point, these two conditions coincide (barring overflow).
    -- Both are kept so the emitted code stays unchanged.
    let no_remaining_elements = tvExp remaining_elements .<=. 0
        beyond_bounds = Imp.unCount num_elements .<=. tvExp starting_point

    sIf
      (no_remaining_elements .||. beyond_bounds)
      (chunk_var <-- 0)
      ( sIf
          is_last_thread
          (chunk_var <-- Imp.unCount last_thread_elements)
          (chunk_var <-- Imp.unCount elements_per_thread)
      )
    where
      -- Elements left after all preceding threads' full ranges.
      last_thread_elements =
        num_elements - Imp.elements thread_index * elements_per_thread
      -- True for the thread whose range extends past num_elements.
      is_last_thread =
        Imp.unCount num_elements
          .<. (thread_index + 1) * Imp.unCount elements_per_thread

-- | Generate the per-thread part of a segmented reduction.
--
-- The generated code:
--
-- * neutral-initialises the accumulators;
-- * runs the map body (@body@) over this thread's elements;
-- * folds each map result into the accumulators using the reduction
--   operator's body.
--
-- The loop shape depends on commutativity:
--
-- * 'Commutative': the thread performs @chunk_size@ iterations (see
--   'computeThreadChunkSize').  Its element indices are strided by
--   @threads_per_segment@.  No group-wide reduction happens here.
--
-- * 'Noncommutative': every thread performs exactly @elems_per_thread@
--   iterations, reading one element per thread in a group-sized window.
--   Out-of-range elements are skipped.  After each iteration a group-wide
--   reduction is performed; thread 0 keeps the result in its accumulator
--   and all other threads reset theirs to the neutral element.
--
-- Returns the renamed reduction lambdas and the group-wide reduction
-- action.  'reductionStageOne' runs that action for commutative
-- operators.
reductionStageZero ::
  KernelConstants ->
  [(VName, Imp.TExp Int64)] ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  Imp.TExp Int64 ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  Imp.TExp Int64 ->
  [SegBinOpSlug] ->
  DoSegBody ->
  InKernelGen ([Lambda GPUMem], InKernelGen ())
reductionStageZero
  constants
  ispace
  num_elements
  global_tid
  elems_per_thread
  threads_per_segment
  slugs
  body = do
    let (gtids, _dims) = unzip ispace
        -- The innermost index is assigned per element in the loop below.
        -- NOTE(review): 'last' assumes ispace is non-empty.
        gtid = mkTV (last gtids) int64
        local_tid = sExt64 $ kernelLocalThreadId constants
        comm = slugsComm slugs

        -- Write each slug's neutral element into all its accumulators.
        resetAccsToNeutral :: InKernelGen ()
        resetAccsToNeutral =
          forM_ slugs $ \slug ->
            forM_ (zip (slugAccs slug) (slugNeutral slug)) $ \((acc, acc_is), ne) ->
              sLoopNest (slugShape slug) $ \vec_is ->
                copyDWIMFix acc (acc_is ++ vec_is) ne []

    -- Figure out how many elements this thread should process.
    -- global_tid is already an Int64 expression, so no extension is needed.
    chunk_size <- dPrim "chunk_size" int64
    computeThreadChunkSize
      comm
      threads_per_segment
      global_tid
      elems_per_thread
      num_elements
      chunk_size

    dScope Nothing $ scopeOfLParams $ concatMap slugParams slugs

    sComment "neutral-initialise the accumulators" resetAccsToNeutral

    slugs_op_renamed <- mapM (renameLambda . segBinOpLambda . slugOp) slugs

    -- Group-wide reduction of the per-thread accumulators.  The result
    -- ends up in thread 0's accumulator.
    let doTheReduction :: InKernelGen ()
        doTheReduction =
          forM_ (zip slugs_op_renamed slugs) $ \(slug_op_renamed, slug) ->
            sLoopNest (slugShape slug) $ \vec_is -> do
              comment "to reduce current chunk, first store our result in memory" $ do
                forM_ (zip (slugParams slug) (slugAccs slug)) $ \(p, (acc, acc_is)) ->
                  copyDWIMFix (paramName p) [] (Var acc) (acc_is ++ vec_is)

                forM_ (zip (slugArrs slug) (slugParams slug)) $ \(arr, p) ->
                  when (primType $ paramType p) $
                    copyDWIMFix arr [local_tid] (Var $ paramName p) []

              sOp $ Imp.ErrorSync Imp.FenceLocal -- Also implicitly barrier.
              groupReduce (sExt32 (kernelGroupSize constants)) slug_op_renamed (slugArrs slug)

              sOp $ Imp.Barrier Imp.FenceLocal

              sComment "first thread saves the result in accumulator" $
                sWhen (local_tid .==. 0) $
                  forM_ (zip (slugAccs slug) (lambdaParams slug_op_renamed)) $ \((acc, acc_is), p) ->
                    copyDWIMFix acc (acc_is ++ vec_is) (Var $ paramName p) []

    -- If this is a non-commutative reduction, each thread must run the
    -- loop the same number of iterations, because we will be performing
    -- a group-wide reduction in there.
    let (bound, check_bounds) =
          case comm of
            Commutative -> (tvExp chunk_size, id)
            Noncommutative ->
              ( Imp.unCount elems_per_thread,
                sWhen (tvExp gtid .<. Imp.unCount num_elements)
              )

    sFor "i" bound $ \i -> do
      gtid
        <-- case comm of
          Commutative ->
            global_tid + threads_per_segment * i
          Noncommutative ->
            -- Window i of this group starts at
            -- (index_in_segment * elems_per_thread + i) * group_size.
            -- local_tid is already Int64.
            let group_size = kernelGroupSize constants
                index_in_segment = global_tid `quot` group_size
             in local_tid
                  + (index_in_segment * Imp.unCount elems_per_thread + i) * group_size

      check_bounds $
        sComment "apply map function" $
          body $ \all_red_res -> do
            let slugs_res = chunks (map (length . slugNeutral) slugs) all_red_res

            forM_ (zip slugs slugs_res) $ \(slug, red_res) ->
              sLoopNest (slugShape slug) $ \vec_is -> do
                sComment "load accumulator" $
                  forM_ (zip (accParams slug) (slugAccs slug)) $ \(p, (acc, acc_is)) ->
                    copyDWIMFix (paramName p) [] (Var acc) (acc_is ++ vec_is)
                sComment "load new values" $
                  forM_ (zip (nextParams slug) red_res) $ \(p, (res, res_is)) ->
                    copyDWIMFix (paramName p) [] res (res_is ++ vec_is)
                let op_results = map resSubExp $ bodyResult $ slugBody slug
                sComment "apply reduction operator" $
                  compileStms mempty (bodyStms $ slugBody slug) $
                    sComment "store in accumulator" $
                      forM_ (zip (slugAccs slug) op_results) $ \((acc, acc_is), se) ->
                        copyDWIMFix acc (acc_is ++ vec_is) se []

      case comm of
        Noncommutative -> do
          doTheReduction
          sComment "first thread keeps accumulator; others reset to neutral element" $
            sUnless (local_tid .==. 0) resetAccsToNeutral
        Commutative -> pure ()

    pure (slugs_op_renamed, doTheReduction)

-- | Generate the per-thread reduction via 'reductionStageZero'.
--
-- For commutative operators this also runs the group-wide reduction
-- that 'reductionStageZero' returns.  For non-commutative operators
-- 'reductionStageZero' already performs group-wide reductions inside its
-- element loop, so nothing more is done here.
--
-- Returns the renamed reduction lambdas.
reductionStageOne ::
  KernelConstants ->
  [(VName, Imp.TExp Int64)] ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  Imp.TExp Int64 ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  Imp.TExp Int64 ->
  [SegBinOpSlug] ->
  DoSegBody ->
  InKernelGen [Lambda GPUMem]
reductionStageOne
  constants
  ispace
  num_elements
  global_tid
  elems_per_thread
  threads_per_segment
  slugs
  body = do
    (slugs_op_renamed, doTheReduction) <-
      reductionStageZero
        constants
        ispace
        num_elements
        global_tid
        elems_per_thread
        threads_per_segment
        slugs
        body

    case slugsComm slugs of
      Noncommutative -> pure ()
      Commutative -> doTheReduction

    pure slugs_op_renamed

reductionStageTwo ::
  KernelConstants ->
  [PatElem LetDecMem] ->
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  [Imp.TExp Int64] ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  SegBinOpSlug ->
  [LParam GPUMem] ->
  [LParam GPUMem] ->
  Lambda GPUMem ->
  [SubExp] ->
  Imp.TExp Int32 ->
  VName ->
  Imp.TExp Int32 ->
  VName ->
  [VName] ->
  [VName] ->
  InKernelGen ()
reductionStageTwo :: KernelConstants
-> [PatElem LParamMem]
-> TPrimExp Int32 VName
-> TPrimExp Int32 VName
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> SegBinOpSlug
-> [LParam GPUMem]
-> [LParam GPUMem]
-> Lambda GPUMem
-> [SubExp]
-> TPrimExp Int32 VName
-> VName
-> TPrimExp Int32 VName
-> VName
-> [VName]
-> [VName]
-> InKernelGen ()
reductionStageTwo
  KernelConstants
constants
  [PatElem LParamMem]
segred_pes
  TPrimExp Int32 VName
group_id
  TPrimExp Int32 VName
flat_segment_id
  [TPrimExp Int64 VName]
segment_gtids
  TPrimExp Int64 VName
first_group_for_segment
  TPrimExp Int64 VName
groups_per_segment
  SegBinOpSlug
slug
  [LParam GPUMem]
red_x_params
  [LParam GPUMem]
red_y_params
  Lambda GPUMem
red_op_renamed
  [SubExp]
nes
  TPrimExp Int32 VName
num_counters
  VName
counter
  TPrimExp Int32 VName
counter_i
  VName
sync_arr
  [VName]
group_res_arrs
  [VName]
red_arrs = do
    let local_tid :: TPrimExp Int32 VName
local_tid = KernelConstants -> TPrimExp Int32 VName
kernelLocalThreadId KernelConstants
constants
        group_size :: TPrimExp Int64 VName
group_size = KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants
    TV Int64
old_counter <- forall {k1} {k2} (rep :: k1) r op (t :: k2).
String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"old_counter" PrimType
int32
    (VName
counter_mem, Space
_, Count Elements (TPrimExp Int64 VName)
counter_offset) <-
      forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> ImpM
     rep r op (VName, Space, Count Elements (TPrimExp Int64 VName))
fullyIndexArray
        VName
counter
        [ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$
            TPrimExp Int32 VName
counter_i forall a. Num a => a -> a -> a
* TPrimExp Int32 VName
num_counters
              forall a. Num a => a -> a -> a
+ TPrimExp Int32 VName
flat_segment_id forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int32 VName
num_counters
        ]
    forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
comment Text
"first thread in group saves group result to global memory" forall a b. (a -> b) -> a -> b
$
      forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Int32 VName
local_tid forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 VName
0) forall a b. (a -> b) -> a -> b
$ do
        forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a. Int -> [a] -> [a]
take (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
group_res_arrs (SegBinOpSlug -> [(VName, [TPrimExp Int64 VName])]
slugAccs SegBinOpSlug
slug)) forall a b. (a -> b) -> a -> b
$ \(VName
v, (VName
acc, [TPrimExp Int64 VName]
acc_is)) ->
          forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
v [TPrimExp Int64 VName
0, forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
group_id] (VName -> SubExp
Var VName
acc) [TPrimExp Int64 VName]
acc_is
        forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceGlobal
        -- Increment the counter, thus stating that our result is
        -- available.
        forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp
          forall a b. (a -> b) -> a -> b
$ Space -> AtomicOp -> KernelOp
Imp.Atomic Space
DefaultSpace
          forall a b. (a -> b) -> a -> b
$ IntType
-> VName
-> VName
-> Count Elements (TPrimExp Int64 VName)
-> Exp
-> AtomicOp
Imp.AtomicAdd
            IntType
Int32
            (forall {k} (t :: k). TV t -> VName
tvVar TV Int64
old_counter)
            VName
counter_mem
            Count Elements (TPrimExp Int64 VName)
counter_offset
          forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int32 VName
1 :: Imp.TExp Int32)
        -- Now check if we were the last group to write our result.  If
        -- so, it is our responsibility to produce the final result.
        forall {k} (rep :: k) r op.
VName -> [TPrimExp Int64 VName] -> Exp -> ImpM rep r op ()
sWrite VName
sync_arr [TPrimExp Int64 VName
0] forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
old_counter forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
groups_per_segment forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1

    forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceGlobal

    TV Bool
is_last_group <- forall {k1} {k2} (rep :: k1) r op (t :: k2).
String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"is_last_group" PrimType
Bool
    forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Bool
is_last_group) [] (VName -> SubExp
Var VName
sync_arr) [TPrimExp Int64 VName
0]
    forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall {k} (t :: k). TV t -> TExp t
tvExp TV Bool
is_last_group) forall a b. (a -> b) -> a -> b
$ do
      -- The final group has written its result (and it was
      -- us!), so read in all the group results and perform the
      -- final stage of the reduction.  But first, we reset the
      -- counter so it is ready for next time.  This is done
      -- with an atomic to avoid warnings about write/write
      -- races in oclgrind.
      forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Int32 VName
local_tid forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 VName
0) forall a b. (a -> b) -> a -> b
$
        forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$
          Space -> AtomicOp -> KernelOp
Imp.Atomic Space
DefaultSpace forall a b. (a -> b) -> a -> b
$
            IntType
-> VName
-> VName
-> Count Elements (TPrimExp Int64 VName)
-> Exp
-> AtomicOp
Imp.AtomicAdd IntType
Int32 (forall {k} (t :: k). TV t -> VName
tvVar TV Int64
old_counter) VName
counter_mem Count Elements (TPrimExp Int64 VName)
counter_offset forall a b. (a -> b) -> a -> b
$
              forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped forall a b. (a -> b) -> a -> b
$
                forall a. Num a => a -> a
negate TPrimExp Int64 VName
groups_per_segment

      forall {k} (rep :: k) r op.
Shape
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
vec_is -> do
        -- There is no guarantee that the number of workgroups for the
        -- segment is less than the workgroup size, so each thread may
        -- have to read multiple elements.  We do this in a sequential
        -- way that may induce non-coalesced accesses, but the total
        -- number of accesses should be tiny here.
        forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
comment Text
"read in the per-group-results" forall a b. (a -> b) -> a -> b
$ do
          TPrimExp Int64 VName
read_per_thread <-
            forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"read_per_thread" forall a b. (a -> b) -> a -> b
$
              TPrimExp Int64 VName
groups_per_segment forall e. IntegralExp e => e -> e -> e
`divUp` forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 VName
group_size

          forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [LParam GPUMem]
red_x_params [SubExp]
nes) forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
ne) ->
            forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
ne []

          forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TPrimExp Int64 VName
read_per_thread forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
i -> do
            TPrimExp Int64 VName
group_res_id <-
              forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"group_res_id" forall a b. (a -> b) -> a -> b
$
                forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
local_tid forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
read_per_thread forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
i
            TPrimExp Int64 VName
index_of_group_res <-
              forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"index_of_group_res" forall a b. (a -> b) -> a -> b
$
                TPrimExp Int64 VName
first_group_for_segment forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
group_res_id

            forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Int64 VName
group_res_id forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 VName
groups_per_segment) forall a b. (a -> b) -> a -> b
$ do
              forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [LParam GPUMem]
red_y_params [VName]
group_res_arrs) forall a b. (a -> b) -> a -> b
$
                \(Param LParamMem
p, VName
group_res_arr) ->
                  forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
                    (forall dec. Param dec -> VName
paramName Param LParamMem
p)
                    []
                    (VName -> SubExp
Var VName
group_res_arr)
                    ([TPrimExp Int64 VName
0, TPrimExp Int64 VName
index_of_group_res] forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
vec_is)

              forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ SegBinOpSlug -> Body GPUMem
slugBody SegBinOpSlug
slug) forall a b. (a -> b) -> a -> b
$
                forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [LParam GPUMem]
red_x_params forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ SegBinOpSlug -> Body GPUMem
slugBody SegBinOpSlug
slug) forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
se) ->
                  forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
se []

        forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [LParam GPUMem]
red_x_params [VName]
red_arrs) forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, VName
arr) ->
          forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (forall shape u. TypeBase shape u -> Bool
primType forall a b. (a -> b) -> a -> b
$ forall dec. Typed dec => Param dec -> Type
paramType Param LParamMem
p) forall a b. (a -> b) -> a -> b
$
            forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
arr [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
local_tid] (VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param LParamMem
p) []

        forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal

        forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"reduce the per-group results" forall a b. (a -> b) -> a -> b
$ do
          TPrimExp Int32 VName -> Lambda GPUMem -> [VName] -> InKernelGen ()
groupReduce (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
group_size) Lambda GPUMem
red_op_renamed [VName]
red_arrs

          forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"and back to memory with the final result" forall a b. (a -> b) -> a -> b
$
            forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Int32 VName
local_tid forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 VName
0) forall a b. (a -> b) -> a -> b
$
              forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
segred_pes forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
red_op_renamed) forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, Param LParamMem
p) ->
                forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
                  (forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
                  ([TPrimExp Int64 VName]
segment_gtids forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
vec_is)
                  (VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param LParamMem
p)
                  []