{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | We generate code for non-segmented/single-segment SegRed using
-- the basic approach outlined in the paper "Design and GPGPU
-- Performance of Futhark’s Redomap Construct" (ARRAY '16).  The main
-- deviations are:
--
-- * While we still use two-phase reduction, we use only a single
--   kernel, with the final workgroup to write a result (tracked via
--   an atomic counter) performing the final reduction as well.
--
-- * Instead of depending on storage layout transformations to handle
--   non-commutative reductions efficiently, we slide a
--   @groupsize@-sized window over the input, and perform a parallel
--   reduction for each window.  This sacrifices the notion of
--   efficient sequentialisation, but is sometimes faster and
--   definitely simpler and more predictable (and uses less auxiliary
--   storage).
--
-- For segmented reductions we use the approach from "Strategies for
-- Regular Segmented Reductions on GPU" (FHPC '17).  This involves
-- having two different strategies, and dynamically deciding which one
-- to use based on the number of segments and segment size. We use the
-- (static) @group_size@ to decide which of the following two
-- strategies to choose:
--
-- * Large: uses one or more groups to process a single segment. If
--   multiple groups are used per segment, the intermediate reduction
--   results must be recursively reduced, until there is only a single
--   value per segment.
--
--   Each thread /can/ read multiple elements, which will greatly
--   increase performance; however, if the reduction is
--   non-commutative we will have to use a less efficient traversal
--   (with interim group-wide reductions) to enable coalesced memory
--   accesses, just as in the non-segmented case.
--
-- * Small: lets each group process *multiple* segments. We will
--   only use this approach when we can
--   process at least two segments within a single group. In those
--   cases, we would allocate a /whole/ group per segment with the
--   large strategy, but at most 50% of the threads in the group would
--   have any element to read, which becomes highly inefficient.
module Futhark.CodeGen.ImpGen.GPU.SegRed
  ( compileSegRed,
    compileSegRed',
    DoSegBody,
  )
where

import Control.Monad.Except
import Data.List (genericLength, zip7)
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.GPU as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.Error
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Transform.Rename
import Futhark.Util (chunks)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)

-- | The maximum number of operators we support in a single SegRed.
-- This limit arises out of the static allocation of counters: the
-- non-segmented reduction statically allocates one counter per
-- potential operator.
maxNumOps :: Int32
maxNumOps = 10

-- | Code generation for the body of the SegRed, taking a continuation
-- for saving the results of the body.  The results should be
-- represented as a pairing of a t'SubExp' along with a list of
-- indexes into that t'SubExp' for reading the result.
--
-- The body is responsible for invoking the continuation itself;
-- 'compileSegRed' builds such a body from a kernel body and passes an
-- empty index list for every reduction result.
type DoSegBody = ([(SubExp, [Imp.TExp Int64])] -> InKernelGen ()) -> InKernelGen ()

-- | Compile 'SegRed' instance to host-level code with calls to
-- various kernels.
--
-- This is a thin wrapper around 'compileSegRed'' that turns the
-- 'KernelBody' into a 'DoSegBody'.  The generated body:
--
-- 1. compiles the statements of the kernel body;
--
-- 2. splits the kernel results into the leading reduction results
--    (as many as 'segBinOpResults' reports for the operators) and the
--    remaining map-out results;
--
-- 3. writes the map-out results to the corresponding trailing
--    pattern elements with 'compileThreadResult';
--
-- 4. hands the reduction results to the continuation, each paired
--    with an empty index list.
compileSegRed ::
  Pattern GPUMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegRed pat lvl space reds body =
  compileSegRed' pat lvl space reds $ \red_cont ->
    compileStms mempty (kernelBodyStms body) $ do
      -- The number of reduction results decides both how the kernel
      -- results and how the pattern elements are partitioned.
      let num_red_res = segBinOpResults reds
          (red_res, map_res) = splitAt num_red_res $ kernelBodyResult body

      sComment "save map-out results" $ do
        let map_arrs = drop num_red_res $ patternElements pat
        zipWithM_ (compileThreadResult space) map_arrs map_res

      red_cont $ zip (map kernelResultSubExp red_res) $ repeat []

-- | Like 'compileSegRed', but where the body is a monadic action.
--
-- Picks the code generation strategy:
--
-- * If there are more than 'maxNumOps' reduction operators, compilation
--   is aborted with a compiler limitation error.
--
-- * If the segment space has exactly two dimensions and the outer one
--   is the constant 1, the non-segmented strategy is used.
--
-- * Otherwise code is emitted that chooses at run time between the
--   small-segments and large-segments strategies: small segments are
--   used when twice the segment size (the innermost dimension) is
--   still less than the group size.
compileSegRed' ::
  Pattern GPUMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  DoSegBody ->
  CallKernelGen ()
compileSegRed' pat lvl space reds body
  | genericLength reds > maxNumOps =
    compilerLimitationS $
      "compileSegRed': at most " ++ show maxNumOps ++ " reduction operators are supported."
  | [(_, Constant (IntValue (Int64Value 1))), _] <- unSegSpace space =
    nonsegmentedReduction pat num_groups group_size space reds body
  | otherwise = do
    let group_size' = toInt64Exp $ unCount group_size
        -- The segment size is the size of the innermost dimension.
        segment_size = toInt64Exp $ last $ segSpaceDims space
        -- At least two segments must fit in a single group.
        use_small_segments = segment_size * 2 .<. group_size'
    sIf
      use_small_segments
      (smallSegmentsReduction pat num_groups group_size space reds body)
      (largeSegmentsReduction pat num_groups group_size space reds body)
  where
    num_groups = segNumGroups lvl
    group_size = segGroupSize lvl

-- | Prepare intermediate arrays for the reduction.  Prim-typed
-- arguments go in local memory (so we need to do the allocation of
-- those arrays inside the kernel), while array-typed arguments go in
-- global memory.  Allocations for the latter have already been
-- performed.  This policy is baked into how the allocations are done
-- in ExplicitAllocations.
--
-- One array is produced per accumulator parameter of the reduction
-- operator (the first @length nes@ lambda parameters):
--
-- * For an array-typed parameter, no allocation happens here; an
--   array of shape @[num_threads] ++ shape@ is declared in the
--   parameter's existing memory block with a row-major ('IxFun.iota')
--   index function.
--
-- * For any other parameter, a fresh array of @group_size@ elements is
--   allocated in the @"local"@ space.
intermediateArrays ::
  Count GroupSize SubExp ->
  SubExp ->
  SegBinOp GPUMem ->
  InKernelGen [VName]
intermediateArrays (Count group_size) num_threads (SegBinOp _ red_op nes _) = do
  -- Only the accumulator parameters need intermediate storage.
  let red_acc_params = take (length nes) $ lambdaParams red_op
  forM red_acc_params $ \p ->
    case paramDec p of
      MemArray pt shape _ (ArrayIn mem _) -> do
        let shape' = Shape [num_threads] <> shape
        sArray "red_arr" pt shape' $
          ArrayIn mem $ IxFun.iota $ map pe64 $ shapeDims shape'
      _ -> do
        let pt = elemType $ paramType p
            shape = Shape [group_size]
        sAllocArray "red_arr" pt shape $ Space "local"

-- | Arrays for storing group results.
--
-- The group-result arrays have an extra dimension because they are
-- also used for keeping vectorised accumulators for first-stage
-- reduction, if necessary.  If necessary, this dimension has size
-- group_size, and otherwise 1.  When actually storing group results,
-- the first index is set to 0.
--
-- The result contains one list per reduction operator, with one array
-- per return value of that operator's lambda.  Each array is allocated
-- in the @"device"@ space.
groupResultArrays ::
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  [SegBinOp GPUMem] ->
  CallKernelGen [[VName]]
groupResultArrays (Count virt_num_groups) (Count group_size) reds =
  forM reds $ \(SegBinOp _ lam _ shape) ->
    forM (lambdaReturnType lam) $ \t -> do
      let pt = elemType t
          -- Primitive results need no per-thread accumulator slot.
          extra_dim
            | primType t = intConst Int64 1
            | otherwise = group_size
          full_shape = Shape [extra_dim, virt_num_groups] <> shape <> arrayShape t
          -- Move the groupsize dimension last to ensure coalesced
          -- memory access.
          perm = [1 .. shapeRank full_shape - 1] ++ [0]
      sAllocArrayPerm "segred_tmp" pt full_shape (Space "device") perm

-- | Code generation for a non-segmented (single-segment) reduction.
--
-- On the host side this sets up a statically allocated counter array
-- (one zero-initialised 32-bit counter per potential operator, i.e.
-- 'maxNumOps' of them), the per-group result arrays, and the total
-- number of threads.  It then launches a single @segred_nonseg@
-- kernel that runs 'reductionStageOne' once over all operators,
-- followed by 'reductionStageTwo' once per operator.
--
-- The width of the reduction is taken to be the last dimension of the
-- segment space, which is therefore assumed to be non-empty.
nonsegmentedReduction ::
  Pattern GPUMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  DoSegBody ->
  CallKernelGen ()
nonsegmentedReduction segred_pat num_groups group_size space reds body = do
  let (gtids, dims) = unzip $ unSegSpace space
      dims' = map toInt64Exp dims
      num_groups' = fmap toInt64Exp num_groups
      group_size' = fmap toInt64Exp group_size
      global_tid = Imp.vi64 $ segFlat space
      w = last dims'

  counter <-
    sStaticArray "counter" (Space "device") int32 $
      Imp.ArrayValues $ replicate (fromIntegral maxNumOps) $ IntValue $ Int32Value 0

  reds_group_res_arrs <- groupResultArrays num_groups group_size reds

  num_threads <-
    dPrimV "num_threads" $
      unCount num_groups' * unCount group_size'

  emit $ Imp.DebugPrint "\n# SegRed" Nothing

  sKernelThread "segred_nonseg" num_groups' group_size' (segFlat space) $ do
    constants <- kernelConstants <$> askEnv
    sync_arr <- sAllocArray "sync_arr" Bool (Shape [intConst Int32 1]) $ Space "local"
    reds_arrs <- mapM (intermediateArrays group_size (tvSize num_threads)) reds

    -- Since this is the nonsegmented case, all outer segment IDs must
    -- necessarily be 0.
    forM_ gtids $ \v -> dPrimV_ v (0 :: Imp.TExp Int64)

    let num_elements = Imp.elements w
        -- Round up so that every element is covered by some thread.
        elems_per_thread =
          num_elements
            `divUp` Imp.elements (sExt64 (kernelNumThreads constants))

    slugs <-
      mapM (segBinOpSlug (kernelLocalThreadId constants) (kernelGroupId constants)) $
        zip3 reds reds_arrs reds_group_res_arrs
    reds_op_renamed <-
      reductionStageOne
        constants
        (zip gtids dims')
        num_elements
        global_tid
        elems_per_thread
        (tvVar num_threads)
        slugs
        body

    -- Each operator owns as many pattern elements as it has neutral
    -- elements.
    let segred_pes =
          chunks (map (length . segBinOpNeutral) reds) $
            patternElements segred_pat
    -- The trailing index @i@ numbers the operators; it is passed on to
    -- 'reductionStageTwo' alongside the shared counter array.
    forM_ (zip7 reds reds_arrs reds_group_res_arrs segred_pes slugs reds_op_renamed [0 ..]) $
      \(SegBinOp _ red_op nes _, red_arrs, group_res_arrs, pes, slug, red_op_renamed, i) -> do
        let (red_x_params, red_y_params) = splitAt (length nes) $ lambdaParams red_op
        reductionStageTwo
          constants
          pes
          (kernelGroupId constants)
          0
          [0]
          0
          (sExt64 $ kernelNumGroups constants)
          slug
          red_x_params
          red_y_params
          red_op_renamed
          nes
          1
          counter
          (fromInteger i)
          sync_arr
          group_res_arrs
          red_arrs

  emit $ Imp.DebugPrint "" Nothing

-- | Generate the kernel for the "small segments" strategy, in which
-- every workgroup reduces several whole segments at once.
--
-- The innermost dimension of the 'SegSpace' is the segment size; the
-- product of the remaining dimensions is the number of segments.  A
-- group handles @group_size \`quot\` max 1 segment_size@ segments per
-- (virtual) iteration.  In each iteration every thread either runs
-- the map body and stores its results in the intermediate arrays at
-- its local thread id, or (when out of bounds) stores the neutral
-- elements there instead.  A segmented group-wide scan then
-- accumulates the values within each segment, and the last element of
-- every segment is copied to the pattern elements as the result of
-- that segment.
--
-- Arguments, in order: the result pattern, the number of groups, the
-- group size, the iteration space, the reduction operators, and the
-- function that produces the per-thread values to be reduced.
smallSegmentsReduction ::
  Pattern GPUMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  DoSegBody ->
  CallKernelGen ()
smallSegmentsReduction (Pattern _ segred_pes) num_groups group_size space reds body = do
  let (gtids, dims) = unzip $ unSegSpace space
      dims' = map toInt64Exp dims
      segment_size = last dims'

  -- Careful to avoid division by zero now.
  segment_size_nonzero <-
    dPrimVE "segment_size_nonzero" $ sMax64 1 segment_size

  let num_groups' = fmap toInt64Exp num_groups
      group_size' = fmap toInt64Exp group_size
  num_threads <- dPrimV "num_threads" $ unCount num_groups' * unCount group_size'
  let num_segments = product $ init dims'
      segments_per_group = unCount group_size' `quot` segment_size_nonzero
      required_groups = sExt32 $ num_segments `divUp` segments_per_group

  emit $ Imp.DebugPrint "\n# SegRed-small" Nothing
  emit $ Imp.DebugPrint "num_segments" $ Just $ untyped num_segments
  emit $ Imp.DebugPrint "segment_size" $ Just $ untyped segment_size
  emit $ Imp.DebugPrint "segments_per_group" $ Just $ untyped segments_per_group
  emit $ Imp.DebugPrint "required_groups" $ Just $ untyped required_groups

  sKernelThread "segred_small" num_groups' group_size' (segFlat space) $ do
    constants <- kernelConstants <$> askEnv
    reds_arrs <- mapM (intermediateArrays group_size (Var $ tvVar num_threads)) reds

    -- We probably do not have enough actual workgroups to cover the
    -- entire iteration space.  Some groups thus have to perform double
    -- duty; we put an outer loop to accomplish this.
    virtualiseGroups SegVirt required_groups $ \group_id' -> do
      -- Compute the 'n' input indices.  The outer 'n-1' correspond to
      -- the segment ID, and are computed from the group id.  The inner
      -- is computed from the local thread id, and may be out-of-bounds.
      let ltid = sExt64 $ kernelLocalThreadId constants
          segment_index =
            (ltid `quot` segment_size_nonzero)
              + (sExt64 group_id' * sExt64 segments_per_group)
          -- Use the clamped size so that the remainder is well defined
          -- even when segments are empty.  The value is only consumed
          -- by the map body, which runs solely when segment_size > 0,
          -- and in that case the clamped and actual sizes coincide.
          index_within_segment = ltid `rem` segment_size_nonzero

      dIndexSpace (zip (init gtids) (init dims')) segment_index
      dPrimV_ (last gtids) index_within_segment

      let -- Threads without an element to process contribute the
          -- neutral element of each operator.
          out_of_bounds =
            forM_ (zip reds reds_arrs) $ \(SegBinOp _ _ nes _, red_arrs) ->
              forM_ (zip red_arrs nes) $ \(arr, ne) ->
                copyDWIMFix arr [ltid] ne []

          -- Threads with an element run the map body and store its
          -- results at their own slot of the intermediate arrays.
          in_bounds =
            body $ \red_res ->
              sComment "save results to be reduced" $ do
                let red_dests = zip (concat reds_arrs) $ repeat [ltid]
                forM_ (zip red_dests red_res) $ \((d, d_is), (res, res_is)) ->
                  copyDWIMFix d d_is res res_is

      sComment "apply map function if in bounds" $
        sIf
          ( segment_size .>. 0
              .&&. isActive (init $ zip gtids dims)
              .&&. ltid .<. segment_size * segments_per_group
          )
          in_bounds
          out_of_bounds

      sOp $ Imp.ErrorSync Imp.FenceLocal -- Also implicitly barrier.

      -- The scan must not combine elements of different segments.
      -- This is only evaluated under the segment_size > 0 guard
      -- below, so the remainder by segment_size is safe here.
      let crossesSegment from to =
            (sExt64 to - sExt64 from) .>. (sExt64 to `rem` segment_size)
      sWhen (segment_size .>. 0) $
        sComment "perform segmented scan to imitate reduction" $
          forM_ (zip reds reds_arrs) $ \(SegBinOp _ red_op _ _, red_arrs) ->
            groupScan
              (Just crossesSegment)
              (sExt64 $ tvExp num_threads)
              (segment_size * segments_per_group)
              red_op
              red_arrs

      sOp $ Imp.Barrier Imp.FenceLocal

      sComment "save final values of segments" $
        sWhen
          ( sExt64 group_id' * segments_per_group + sExt64 ltid .<. num_segments
              .&&. ltid .<. segments_per_group
          )
          $ forM_ (zip segred_pes (concat reds_arrs)) $ \(pe, arr) -> do
            -- Figure out which segment result this thread should write...
            let flat_segment_index =
                  sExt64 group_id' * segments_per_group + sExt64 ltid
                gtids' =
                  unflattenIndex (init dims') flat_segment_index
            -- ...and read the last element of that segment.  For empty
            -- segments the clamped size is 1, so this reads slot
            -- 'ltid', which was filled with the neutral element above.
            copyDWIMFix
              (patElemName pe)
              gtids'
              (Var arr)
              [(ltid + 1) * segment_size_nonzero - 1]

      -- Finally another barrier, because we will be writing to the
      -- local memory array first thing in the next iteration.
      sOp $ Imp.Barrier Imp.FenceLocal

  emit $ Imp.DebugPrint "" Nothing

largeSegmentsReduction ::
  Pattern GPUMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  DoSegBody ->
  CallKernelGen ()
largeSegmentsReduction :: Pattern GPUMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp GPUMem]
-> DoSegBody
-> CallKernelGen ()
largeSegmentsReduction Pattern GPUMem
segred_pat Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegBinOp GPUMem]
reds DoSegBody
body = do
  let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      dims' :: [TPrimExp Int64 ExpLeaf]
dims' = (SubExp -> TPrimExp Int64 ExpLeaf)
-> [SubExp] -> [TPrimExp Int64 ExpLeaf]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 ExpLeaf
forall a. ToExp a => a -> TPrimExp Int64 ExpLeaf
toInt64Exp [SubExp]
dims
      num_segments :: TPrimExp Int64 ExpLeaf
num_segments = [TPrimExp Int64 ExpLeaf] -> TPrimExp Int64 ExpLeaf
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 ExpLeaf] -> TPrimExp Int64 ExpLeaf)
-> [TPrimExp Int64 ExpLeaf] -> TPrimExp Int64 ExpLeaf
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 ExpLeaf] -> [TPrimExp Int64 ExpLeaf]
forall a. [a] -> [a]
init [TPrimExp Int64 ExpLeaf]
dims'
      segment_size :: TPrimExp Int64 ExpLeaf
segment_size = [TPrimExp Int64 ExpLeaf] -> TPrimExp Int64 ExpLeaf
forall a. [a] -> a
last [TPrimExp Int64 ExpLeaf]
dims'
      num_groups' :: Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups' = (SubExp -> TPrimExp Int64 ExpLeaf)
-> Count NumGroups SubExp
-> Count NumGroups (TPrimExp Int64 ExpLeaf)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 ExpLeaf
forall a. ToExp a => a -> TPrimExp Int64 ExpLeaf
toInt64Exp Count NumGroups SubExp
num_groups
      group_size' :: Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size' = (SubExp -> TPrimExp Int64 ExpLeaf)
-> Count GroupSize SubExp
-> Count GroupSize (TPrimExp Int64 ExpLeaf)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 ExpLeaf
forall a. ToExp a => a -> TPrimExp Int64 ExpLeaf
toInt64Exp Count GroupSize SubExp
group_size

  (TPrimExp Int64 ExpLeaf
groups_per_segment, Count Elements (TPrimExp Int64 ExpLeaf)
elems_per_thread) <-
    TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf
-> Count NumGroups (TPrimExp Int64 ExpLeaf)
-> Count GroupSize (TPrimExp Int64 ExpLeaf)
-> CallKernelGen
     (TPrimExp Int64 ExpLeaf, Count Elements (TPrimExp Int64 ExpLeaf))
groupsPerSegmentAndElementsPerThread
      TPrimExp Int64 ExpLeaf
segment_size
      TPrimExp Int64 ExpLeaf
num_segments
      Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups'
      Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size'
  TV Int64
virt_num_groups <-
    String
-> TPrimExp Int64 ExpLeaf -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"virt_num_groups" (TPrimExp Int64 ExpLeaf -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TPrimExp Int64 ExpLeaf -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
      TPrimExp Int64 ExpLeaf
groups_per_segment TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
num_segments

  TV Int64
num_threads <-
    String
-> TPrimExp Int64 ExpLeaf -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"num_threads" (TPrimExp Int64 ExpLeaf -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TPrimExp Int64 ExpLeaf -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
      Count NumGroups (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups' TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size'

  TV Int64
threads_per_segment <-
    String
-> TPrimExp Int64 ExpLeaf -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"threads_per_segment" (TPrimExp Int64 ExpLeaf -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TPrimExp Int64 ExpLeaf -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
      TPrimExp Int64 ExpLeaf
groups_per_segment TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size'

  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"\n# SegRed-large" Maybe Exp
forall a. Maybe a
Nothing
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"num_segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 ExpLeaf
num_segments
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"segment_size" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 ExpLeaf
segment_size
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"virt_num_groups" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 ExpLeaf -> Exp) -> TPrimExp Int64 ExpLeaf -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
virt_num_groups
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"num_groups" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 ExpLeaf -> Exp) -> TPrimExp Int64 ExpLeaf -> Exp
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
Imp.unCount Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups'
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"group_size" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 ExpLeaf -> Exp) -> TPrimExp Int64 ExpLeaf -> Exp
forall a b. (a -> b) -> a -> b
$ Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
Imp.unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size'
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"elems_per_thread" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 ExpLeaf -> Exp) -> TPrimExp Int64 ExpLeaf -> Exp
forall a b. (a -> b) -> a -> b
$ Count Elements (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
Imp.unCount Count Elements (TPrimExp Int64 ExpLeaf)
elems_per_thread
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"groups_per_segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 ExpLeaf
groups_per_segment

  [[VName]]
reds_group_res_arrs <- Count NumGroups SubExp
-> Count GroupSize SubExp
-> [SegBinOp GPUMem]
-> CallKernelGen [[VName]]
groupResultArrays (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Count (TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
virt_num_groups)) Count GroupSize SubExp
group_size [SegBinOp GPUMem]
reds

  -- In principle we should have a counter for every segment.  Since
  -- the number of segments is a dynamic quantity, we would have to
  -- allocate and zero out an array here, which is expensive.
  -- However, we exploit the fact that the number of segments being
  -- reduced at any point in time is limited by the number of
  -- workgroups. If we bound the number of workgroups, we can get away
  -- with using that many counters.  FIXME: Is this limit checked
  -- anywhere?  There are other places in the compiler that will fail
  -- if the group count exceeds the maximum group size, which is at
  -- most 1024 anyway.
  let num_counters :: Int
num_counters = Int32 -> Int
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int32
maxNumOps Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
1024
  VName
counter <-
    String
-> Space
-> PrimType
-> ArrayContents
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String -> Space -> PrimType -> ArrayContents -> ImpM rep r op VName
sStaticArray String
"counter" (String -> Space
Space String
"device") PrimType
int32 (ArrayContents -> ImpM GPUMem HostEnv HostOp VName)
-> ArrayContents -> ImpM GPUMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
      Int -> ArrayContents
Imp.ArrayZeros Int
num_counters

  String
-> Count NumGroups (TPrimExp Int64 ExpLeaf)
-> Count GroupSize (TPrimExp Int64 ExpLeaf)
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"segred_large" Count NumGroups (TPrimExp Int64 ExpLeaf)
num_groups' Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size' (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
    [[VName]]
reds_arrs <- (SegBinOp GPUMem -> InKernelGen [VName])
-> [SegBinOp GPUMem] -> ImpM GPUMem KernelEnv KernelOp [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Count GroupSize SubExp
-> SubExp -> SegBinOp GPUMem -> InKernelGen [VName]
intermediateArrays Count GroupSize SubExp
group_size (TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
num_threads)) [SegBinOp GPUMem]
reds
    VName
sync_arr <- String
-> PrimType
-> Shape
-> Space
-> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
String -> PrimType -> Shape -> Space -> ImpM rep r op VName
sAllocArray String
"sync_arr" PrimType
Bool ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
1]) (Space -> ImpM GPUMem KernelEnv KernelOp VName)
-> Space -> ImpM GPUMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ String -> Space
Space String
"local"

    -- We probably do not have enough actual workgroups to cover the
    -- entire iteration space.  Some groups thus have to perform double
    -- duty; we put an outer loop to accomplish this.
    SegVirt
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt (TPrimExp Int64 ExpLeaf -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TV Int64 -> TPrimExp Int64 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int64
virt_num_groups)) ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
group_id -> do
      let segment_gtids :: [VName]
segment_gtids = [VName] -> [VName]
forall a. [a] -> [a]
init [VName]
gtids
          w :: SubExp
w = [SubExp] -> SubExp
forall a. [a] -> a
last [SubExp]
dims
          local_tid :: TExp Int32
local_tid = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants

      TExp Int32
flat_segment_id <-
        String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"flat_segment_id" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
          TExp Int32
group_id TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp Int64 ExpLeaf -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 ExpLeaf
groups_per_segment

      TPrimExp Int64 ExpLeaf
global_tid <-
        String
-> TPrimExp Int64 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"global_tid" (TPrimExp Int64 ExpLeaf
 -> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf))
-> TPrimExp Int64 ExpLeaf
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 ExpLeaf)
forall a b. (a -> b) -> a -> b
$
          (TExp Int32 -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
group_id TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size') TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
local_tid)
            TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall e. IntegralExp e => e -> e -> e
`rem` (TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count GroupSize (TPrimExp Int64 ExpLeaf) -> TPrimExp Int64 ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 ExpLeaf)
group_size') TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
groups_per_segment)

      let first_group_for_segment :: TPrimExp Int64 ExpLeaf
first_group_for_segment = TExp Int32 -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
flat_segment_id TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall a. Num a => a -> a -> a
* TPrimExp Int64 ExpLeaf
groups_per_segment
      [(VName, TPrimExp Int64 ExpLeaf)]
-> TPrimExp Int64 ExpLeaf -> InKernelGen ()
forall rep r op.
[(VName, TPrimExp Int64 ExpLeaf)]
-> TPrimExp Int64 ExpLeaf -> ImpM rep r op ()
dIndexSpace ([VName]
-> [TPrimExp Int64 ExpLeaf] -> [(VName, TPrimExp Int64 ExpLeaf)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
segment_gtids ([TPrimExp Int64 ExpLeaf] -> [TPrimExp Int64 ExpLeaf]
forall a. [a] -> [a]
init [TPrimExp Int64 ExpLeaf]
dims')) (TPrimExp Int64 ExpLeaf -> InKernelGen ())
-> TPrimExp Int64 ExpLeaf -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$TExp Int32 -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
flat_segment_id
      VName -> PrimType -> InKernelGen ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ ([VName] -> VName
forall a. [a] -> a
last [VName]
gtids) PrimType
int64
      let num_elements :: Count Elements (TPrimExp Int64 ExpLeaf)
num_elements = TPrimExp Int64 ExpLeaf -> Count Elements (TPrimExp Int64 ExpLeaf)
forall a. a -> Count Elements a
Imp.elements (TPrimExp Int64 ExpLeaf -> Count Elements (TPrimExp Int64 ExpLeaf))
-> TPrimExp Int64 ExpLeaf
-> Count Elements (TPrimExp Int64 ExpLeaf)
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 ExpLeaf
forall a. ToExp a => a -> TPrimExp Int64 ExpLeaf
toInt64Exp SubExp
w

      [SegBinOpSlug]
slugs <-
        ((SegBinOp GPUMem, [VName], [VName])
 -> ImpM GPUMem KernelEnv KernelOp SegBinOpSlug)
-> [(SegBinOp GPUMem, [VName], [VName])]
-> ImpM GPUMem KernelEnv KernelOp [SegBinOpSlug]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (TExp Int32
-> TExp Int32
-> (SegBinOp GPUMem, [VName], [VName])
-> ImpM GPUMem KernelEnv KernelOp SegBinOpSlug
segBinOpSlug TExp Int32
local_tid TExp Int32
group_id) ([(SegBinOp GPUMem, [VName], [VName])]
 -> ImpM GPUMem KernelEnv KernelOp [SegBinOpSlug])
-> [(SegBinOp GPUMem, [VName], [VName])]
-> ImpM GPUMem KernelEnv KernelOp [SegBinOpSlug]
forall a b. (a -> b) -> a -> b
$
          [SegBinOp GPUMem]
-> [[VName]] -> [[VName]] -> [(SegBinOp GPUMem, [VName], [VName])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegBinOp GPUMem]
reds [[VName]]
reds_arrs [[VName]]
reds_group_res_arrs
      [Lambda GPUMem]
reds_op_renamed <-
        KernelConstants
-> [(VName, TPrimExp Int64 ExpLeaf)]
-> Count Elements (TPrimExp Int64 ExpLeaf)
-> TPrimExp Int64 ExpLeaf
-> Count Elements (TPrimExp Int64 ExpLeaf)
-> VName
-> [SegBinOpSlug]
-> DoSegBody
-> InKernelGen [Lambda GPUMem]
reductionStageOne
          KernelConstants
constants
          ([VName]
-> [TPrimExp Int64 ExpLeaf] -> [(VName, TPrimExp Int64 ExpLeaf)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [TPrimExp Int64 ExpLeaf]
dims')
          Count Elements (TPrimExp Int64 ExpLeaf)
num_elements
          TPrimExp Int64 ExpLeaf
global_tid
          Count Elements (TPrimExp Int64 ExpLeaf)
elems_per_thread
          (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
threads_per_segment)
          [SegBinOpSlug]
slugs
          DoSegBody
body

      let segred_pes :: [[PatElemT LParamMem]]
segred_pes =
            [Int] -> [PatElemT LParamMem] -> [[PatElemT LParamMem]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOp GPUMem -> Int) -> [SegBinOp GPUMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp GPUMem -> [SubExp]) -> SegBinOp GPUMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral) [SegBinOp GPUMem]
reds) ([PatElemT LParamMem] -> [[PatElemT LParamMem]])
-> [PatElemT LParamMem] -> [[PatElemT LParamMem]]
forall a b. (a -> b) -> a -> b
$
              PatternT LParamMem -> [PatElemT LParamMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern GPUMem
PatternT LParamMem
segred_pat

          multiple_groups_per_segment :: InKernelGen ()
multiple_groups_per_segment =
            [(SegBinOp GPUMem, [VName], [VName], [PatElemT LParamMem],
  SegBinOpSlug, Lambda GPUMem, Integer)]
-> ((SegBinOp GPUMem, [VName], [VName], [PatElemT LParamMem],
     SegBinOpSlug, Lambda GPUMem, Integer)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp GPUMem]
-> [[VName]]
-> [[VName]]
-> [[PatElemT LParamMem]]
-> [SegBinOpSlug]
-> [Lambda GPUMem]
-> [Integer]
-> [(SegBinOp GPUMem, [VName], [VName], [PatElemT LParamMem],
     SegBinOpSlug, Lambda GPUMem, Integer)]
forall a b c d e f g.
[a]
-> [b]
-> [c]
-> [d]
-> [e]
-> [f]
-> [g]
-> [(a, b, c, d, e, f, g)]
zip7 [SegBinOp GPUMem]
reds [[VName]]
reds_arrs [[VName]]
reds_group_res_arrs [[PatElemT LParamMem]]
segred_pes [SegBinOpSlug]
slugs [Lambda GPUMem]
reds_op_renamed [Integer
0 ..]) (((SegBinOp GPUMem, [VName], [VName], [PatElemT LParamMem],
   SegBinOpSlug, Lambda GPUMem, Integer)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((SegBinOp GPUMem, [VName], [VName], [PatElemT LParamMem],
     SegBinOpSlug, Lambda GPUMem, Integer)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              \(SegBinOp Commutativity
_ Lambda GPUMem
red_op [SubExp]
nes Shape
_, [VName]
red_arrs, [VName]
group_res_arrs, [PatElemT LParamMem]
pes, SegBinOpSlug
slug, Lambda GPUMem
red_op_renamed, Integer
i) -> do
                let ([Param LParamMem]
red_x_params, [Param LParamMem]
red_y_params) =
                      Int -> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([Param LParamMem] -> ([Param LParamMem], [Param LParamMem]))
-> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda GPUMem
red_op
                KernelConstants
-> [PatElem GPUMem]
-> TExp Int32
-> TExp Int32
-> [TPrimExp Int64 ExpLeaf]
-> TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf
-> SegBinOpSlug
-> [LParam GPUMem]
-> [LParam GPUMem]
-> Lambda GPUMem
-> [SubExp]
-> TExp Int32
-> VName
-> TExp Int32
-> VName
-> [VName]
-> [VName]
-> InKernelGen ()
reductionStageTwo
                  KernelConstants
constants
                  [PatElem GPUMem]
[PatElemT LParamMem]
pes
                  TExp Int32
group_id
                  TExp Int32
flat_segment_id
                  ((VName -> TPrimExp Int64 ExpLeaf)
-> [VName] -> [TPrimExp Int64 ExpLeaf]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 ExpLeaf
Imp.vi64 [VName]
segment_gtids)
                  (TPrimExp Int64 ExpLeaf -> TPrimExp Int64 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int64 ExpLeaf
first_group_for_segment)
                  TPrimExp Int64 ExpLeaf
groups_per_segment
                  SegBinOpSlug
slug
                  [LParam GPUMem]
[Param LParamMem]
red_x_params
                  [LParam GPUMem]
[Param LParamMem]
red_y_params
                  Lambda GPUMem
red_op_renamed
                  [SubExp]
nes
                  (Int -> TExp Int32
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
num_counters)
                  VName
counter
                  (Integer -> TExp Int32
forall a. Num a => Integer -> a
fromInteger Integer
i)
                  VName
sync_arr
                  [VName]
group_res_arrs
                  [VName]
red_arrs

          one_group_per_segment :: InKernelGen ()
one_group_per_segment =
            String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
comment String
"first thread in group saves final result to memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              [(SegBinOpSlug, [PatElemT LParamMem])]
-> ((SegBinOpSlug, [PatElemT LParamMem]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOpSlug]
-> [[PatElemT LParamMem]] -> [(SegBinOpSlug, [PatElemT LParamMem])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOpSlug]
slugs [[PatElemT LParamMem]]
segred_pes) (((SegBinOpSlug, [PatElemT LParamMem]) -> InKernelGen ())
 -> InKernelGen ())
-> ((SegBinOpSlug, [PatElemT LParamMem]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOpSlug
slug, [PatElemT LParamMem]
pes) ->
                TPrimExp Bool ExpLeaf -> InKernelGen () -> InKernelGen ()
forall rep r op.
TPrimExp Bool ExpLeaf -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                  [(PatElemT LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))]
-> ((PatElemT LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LParamMem]
-> [(VName, [TPrimExp Int64 ExpLeaf])]
-> [(PatElemT LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LParamMem]
pes (SegBinOpSlug -> [(VName, [TPrimExp Int64 ExpLeaf])]
slugAccs SegBinOpSlug
slug)) (((PatElemT LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((PatElemT LParamMem, (VName, [TPrimExp Int64 ExpLeaf]))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LParamMem
v, (VName
acc, [TPrimExp Int64 ExpLeaf]
acc_is)) ->
                    VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 ExpLeaf]
-> SubExp
-> [TPrimExp Int64 ExpLeaf]
-> ImpM rep r op ()
copyDWIMFix (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
v) ((VName -> TPrimExp Int64 ExpLeaf)
-> [VName] -> [TPrimExp Int64 ExpLeaf]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 ExpLeaf
Imp.vi64 [VName]
segment_gtids) (VName -> SubExp
Var VName
acc) [TPrimExp Int64 ExpLeaf]
acc_is

      TPrimExp Bool ExpLeaf
-> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall rep r op.
TPrimExp Bool ExpLeaf
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf (TPrimExp Int64 ExpLeaf
groups_per_segment TPrimExp Int64 ExpLeaf
-> TPrimExp Int64 ExpLeaf -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 ExpLeaf
1) InKernelGen ()
one_group_per_segment InKernelGen ()
multiple_groups_per_segment

  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"" Maybe Exp
forall a. Maybe a
Nothing

-- | Compute how many workgroups are assigned to each segment, and how
-- many elements each thread then has to process.
--
-- Arguments, in order: the segment size, the number of segments, the
-- requested number of workgroups, and the workgroup size.
--
-- Both results are bound with 'dPrimVE' (as @groups_per_segment@ and
-- @elements_per_thread@), and the returned expressions are the ones
-- 'dPrimVE' hands back.
--
-- Careful to avoid division by zero here.  The number of segments is
-- clamped to at least 1 before it is used as a divisor.  The second
-- division is by @group_size * groups_per_segment@, which is nonzero
-- only if we have at least one group per segment.
-- NOTE(review): that in turn presumes the group count hint and the
-- group size are both positive; nothing in this function checks it.
groupsPerSegmentAndElementsPerThread ::
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Count NumGroups (Imp.TExp Int64) ->
  Count GroupSize (Imp.TExp Int64) ->
  CallKernelGen
    ( Imp.TExp Int64,
      Imp.Count Imp.Elements (Imp.TExp Int64)
    )
groupsPerSegmentAndElementsPerThread segment_size num_segments num_groups_hint group_size = do
  groups_per_segment <-
    dPrimVE "groups_per_segment" $
      unCount num_groups_hint `divUp` sMax64 1 num_segments
  elements_per_thread <-
    dPrimVE "elements_per_thread" $
      segment_size `divUp` (unCount group_size * groups_per_segment)
  return (groups_per_segment, Imp.elements elements_per_thread)

-- | A SegBinOp with auxiliary information.
data SegBinOpSlug = SegBinOpSlug
  { -- | The reduction operator itself.
    slugOp :: SegBinOp GPUMem,
    -- | The arrays used for computing the intra-group reduction
    -- (either local or global memory).
    slugArrs :: [VName],
    -- | Places to store accumulator in stage 1 reduction.  Each entry
    -- is a variable name paired with the indices at which it is
    -- accessed (an empty index list for scalar accumulators).
    slugAccs :: [(VName, [Imp.TExp Int64])]
  }

-- | The body of the lambda of the operator held by the slug.
slugBody :: SegBinOpSlug -> Body GPUMem
slugBody = lambdaBody . segBinOpLambda . slugOp

-- | All parameters of the lambda of the operator held by the slug.
slugParams :: SegBinOpSlug -> [LParam GPUMem]
slugParams = lambdaParams . segBinOpLambda . slugOp

-- | The neutral elements of the operator held by the slug.
slugNeutral :: SegBinOpSlug -> [SubExp]
slugNeutral = segBinOpNeutral . slugOp

-- | The shape recorded in the operator held by the slug (its
-- 'segBinOpShape').
slugShape :: SegBinOpSlug -> Shape
slugShape = segBinOpShape . slugOp

-- | The commutativity of a collection of slugs, obtained by combining
-- the commutativity of each operator with the 'Monoid' instance of
-- 'Commutativity'.
slugsComm :: [SegBinOpSlug] -> Commutativity
slugsComm = mconcat . map (segBinOpComm . slugOp)

-- | Split the operator's lambda parameters at the number of neutral
-- elements: 'accParams' is the leading part (as many parameters as
-- there are neutral elements), and 'nextParams' is everything after
-- it.
accParams, nextParams :: SegBinOpSlug -> [LParam GPUMem]
accParams slug = take (length (slugNeutral slug)) $ slugParams slug
nextParams slug = drop (length (slugNeutral slug)) $ slugParams slug

-- | Construct a 'SegBinOpSlug' for one reduction operator.
--
-- The arguments are the local thread id, the group id, and a triple
-- of the operator, the arrays to record as 'slugArrs', and the arrays
-- that may back the accumulators.  At the call site for the
-- @segred_large@ kernel, the second component holds the intermediate
-- (intra-group) arrays and the third the per-group result arrays; the
-- local names below follow that usage.
--
-- One accumulator is created per array in the third component, paired
-- positionally with the operator's lambda parameters ('zipWithM' stops
-- at the shorter list).
segBinOpSlug ::
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  (SegBinOp GPUMem, [VName], [VName]) ->
  InKernelGen SegBinOpSlug
segBinOpSlug local_tid group_id (op, red_arrs, group_res_arrs) =
  SegBinOpSlug op red_arrs
    <$> zipWithM mkAcc (lambdaParams (segBinOpLambda op)) group_res_arrs
  where
    -- A primitive-typed parameter of an operator whose shape has rank
    -- zero gets a freshly declared scalar, named after the parameter
    -- with an "_acc" suffix and accessed with no indices.  Anything
    -- else is accumulated in the given array, at the position indexed
    -- by the local thread id and the group id (both widened with
    -- 'sExt64').
    mkAcc p group_res_arr
      | Prim t <- paramType p,
        shapeRank (segBinOpShape op) == 0 = do
        acc <- dPrim (baseString (paramName p) <> "_acc") t
        return (tvVar acc, [])
      | otherwise =
        return (group_res_arr, [sExt64 local_tid, sExt64 group_id])

-- | Generate the per-thread sequential part of a reduction: every
-- thread folds the elements assigned to it into the accumulators of
-- each operator in @slugs@.
--
-- Arguments, in order:
--
-- * @constants@: supplies the local thread id and the group size.
--
-- * @ispace@: pairs of index variable and dimension; only the /last/
--   index variable is written here, once per loop iteration.
--
-- * @num_elements@: passed to 'computeThreadChunkSize', and used as
--   the upper bound of the index in the non-commutative case.
--
-- * @global_tid@: base of the index computation for this thread.
--
-- * @elems_per_thread@: passed to 'computeThreadChunkSize'; also the
--   loop bound in the non-commutative case.
--
-- * @threads_per_segment@: the stride between consecutive elements of
--   one thread in the commutative case.
--
-- * @slugs@: the reduction operators together with their accumulators.
--
-- * @body@: produces the values to be reduced; it is handed a
--   continuation receiving the results for all operators at once.
--
-- The result is the freshly renamed operator lambdas (one per slug)
-- and an action that emits a group-wide reduction of the
-- accumulators.  For non-commutative operators that action has
-- already been emitted inside the loop, once per iteration; for
-- commutative operators it has not been emitted at all.
reductionStageZero ::
  KernelConstants ->
  [(VName, Imp.TExp Int64)] ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  Imp.TExp Int64 ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  VName ->
  [SegBinOpSlug] ->
  DoSegBody ->
  InKernelGen ([Lambda GPUMem], InKernelGen ())
reductionStageZero constants ispace num_elements global_tid elems_per_thread threads_per_segment slugs body = do
  let (gtids, _dims) = unzip ispace
      -- NOTE(review): 'last' is partial, so this relies on @ispace@
      -- being non-empty.
      gtid = mkTV (last gtids) int64
      local_tid = sExt64 $ kernelLocalThreadId constants
      comm = slugsComm slugs

      -- Emit code that overwrites every accumulator (across the whole
      -- operator shape) with the neutral element of its operator.
      neutraliseAccs :: InKernelGen ()
      neutraliseAccs =
        forM_ slugs $ \slug ->
          forM_ (zip (slugAccs slug) (slugNeutral slug)) $ \((acc, acc_is), ne) ->
            sLoopNest (slugShape slug) $ \vec_is ->
              copyDWIMFix acc (acc_is ++ vec_is) ne []

  -- Figure out how many elements this thread should process.
  chunk_size <- dPrim "chunk_size" int64
  let ordering = case comm of
        Commutative -> SplitStrided $ Var threads_per_segment
        Noncommutative -> SplitContiguous
  computeThreadChunkSize ordering (sExt64 global_tid) elems_per_thread num_elements chunk_size

  -- Bring the parameters of all operators into scope before any of
  -- them is written to.
  dScope Nothing $ scopeOfLParams $ concatMap slugParams slugs

  sComment "neutral-initialise the accumulators" neutraliseAccs

  slugs_op_renamed <- mapM (renameLambda . segBinOpLambda . slugOp) slugs

  let doTheReduction =
        forM_ (zip slugs_op_renamed slugs) $ \(slug_op_renamed, slug) ->
          sLoopNest (slugShape slug) $ \vec_is -> do
            comment "to reduce current chunk, first store our result in memory" $ do
              forM_ (zip (slugParams slug) (slugAccs slug)) $ \(p, (acc, acc_is)) ->
                copyDWIMFix (paramName p) [] (Var acc) (acc_is ++ vec_is)

              -- Only primitive-typed parameters are copied into the
              -- slug's arrays, at this thread's local index.
              forM_ (zip (slugArrs slug) (slugParams slug)) $ \(arr, p) ->
                when (primType $ paramType p) $
                  copyDWIMFix arr [local_tid] (Var $ paramName p) []

            sOp $ Imp.ErrorSync Imp.FenceLocal -- Also implicitly barrier.
            groupReduce (sExt32 (kernelGroupSize constants)) slug_op_renamed (slugArrs slug)

            sOp $ Imp.Barrier Imp.FenceLocal

            sComment "first thread saves the result in accumulator" $
              sWhen (local_tid .==. 0) $
                forM_ (zip (slugAccs slug) (lambdaParams slug_op_renamed)) $ \((acc, acc_is), p) ->
                  copyDWIMFix acc (acc_is ++ vec_is) (Var $ paramName p) []

  -- If this is a non-commutative reduction, each thread must run the
  -- loop the same number of iterations, because we will be performing
  -- a group-wide reduction in there.
  let (bound, check_bounds) =
        case comm of
          Commutative -> (tvExp chunk_size, id)
          Noncommutative ->
            ( Imp.unCount elems_per_thread,
              sWhen (tvExp gtid .<. Imp.unCount num_elements)
            )

  sFor "i" bound $ \i -> do
    gtid
      <-- case comm of
        Commutative ->
          global_tid + Imp.vi64 threads_per_segment * i
        Noncommutative ->
          let index_in_segment = global_tid `quot` kernelGroupSize constants
           in sExt64 local_tid
                + (index_in_segment * Imp.unCount elems_per_thread + i)
                * kernelGroupSize constants

    check_bounds $
      sComment "apply map function" $
        body $ \all_red_res -> do
          -- Split the flat result list into one group per operator,
          -- sized by that operator's number of neutral elements.
          let slugs_res = chunks (map (length . slugNeutral) slugs) all_red_res

          forM_ (zip slugs slugs_res) $ \(slug, red_res) ->
            sLoopNest (slugShape slug) $ \vec_is -> do
              sComment "load accumulator" $
                forM_ (zip (accParams slug) (slugAccs slug)) $ \(p, (acc, acc_is)) ->
                  copyDWIMFix (paramName p) [] (Var acc) (acc_is ++ vec_is)
              sComment "load new values" $
                forM_ (zip (nextParams slug) red_res) $ \(p, (res, res_is)) ->
                  copyDWIMFix (paramName p) [] res (res_is ++ vec_is)
              sComment "apply reduction operator" $
                compileStms mempty (bodyStms $ slugBody slug) $
                  sComment "store in accumulator" $
                    forM_ (zip (slugAccs slug) (bodyResult $ slugBody slug)) $ \((acc, acc_is), se) ->
                      copyDWIMFix acc (acc_is ++ vec_is) se []

    case comm of
      Noncommutative -> do
        doTheReduction
        sComment "first thread keeps accumulator; others reset to neutral element" $
          sUnless (local_tid .==. 0) neutraliseAccs
      _ -> return ()

  return (slugs_op_renamed, doTheReduction)

-- | Run 'reductionStageZero' and then finish the per-group part of
-- the reduction.  All arguments are forwarded unchanged to
-- 'reductionStageZero'.
--
-- * For non-commutative operators, 'reductionStageZero' has already
--   emitted the group-wide reduction inside its loop, so here each
--   accumulator is only copied into the corresponding parameter given
--   by 'accParams'.
--
-- * Otherwise, the group-wide reduction returned by
--   'reductionStageZero' is emitted now.
--
-- Returns the renamed operator lambdas produced by
-- 'reductionStageZero'.
reductionStageOne ::
  KernelConstants ->
  [(VName, Imp.TExp Int64)] ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  Imp.TExp Int64 ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  VName ->
  [SegBinOpSlug] ->
  DoSegBody ->
  InKernelGen [Lambda GPUMem]
reductionStageOne constants ispace num_elements global_tid elems_per_thread threads_per_segment slugs body = do
  (slugs_op_renamed, doTheReduction) <-
    reductionStageZero constants ispace num_elements global_tid elems_per_thread threads_per_segment slugs body

  case slugsComm slugs of
    Noncommutative ->
      -- NOTE(review): unlike the copies in 'reductionStageZero', this
      -- one is not nested in a loop over the operator shape; it reads
      -- the accumulator at @acc_is@ only.  Confirm that this is
      -- intended for operators with a non-empty shape.
      forM_ slugs $ \slug ->
        forM_ (zip (accParams slug) (slugAccs slug)) $ \(p, (acc, acc_is)) ->
          copyDWIMFix (paramName p) [] (Var acc) acc_is
    _ -> doTheReduction

  return slugs_op_renamed

-- | Stage two of the reduction: publish this group's partial result,
-- and - if this group turns out to be the last one to do so for its
-- segment - combine all per-group partial results into the final
-- value and write it to the result pattern.
--
-- The arguments are, in order:
--
-- * @constants@: the kernel constants; the local thread ID and the
--   group size are read from here.
--
-- * @segred_pes@: the pattern elements that receive the final result.
--
-- * @group_id@: the index at which this group stores its partial
--   result in @group_res_arrs@.
--
-- * @flat_segment_id@: used (modulo @num_counters@) to select the
--   counter for this segment.
--
-- * @segment_gtids@: the leading indices used when writing the final
--   result to @segred_pes@.
--
-- * @first_group_for_segment@: offset into @group_res_arrs@ of the
--   first partial result belonging to this segment.
--
-- * @groups_per_segment@: how many groups contribute a partial result
--   to this segment.
--
-- * @slug@: the reduction operator; its accumulators, shape and body
--   are used.
--
-- * @red_x_params@, @red_y_params@: the operator parameters.  The
--   former hold the running result, the latter are loaded with one
--   partial result at a time before the operator body is run.
--
-- * @red_op_renamed@: the renamed operator passed to 'groupReduce';
--   its parameters are read back as the final result.
--
-- * @nes@: the neutral elements, used to initialise @red_x_params@.
--
-- * @num_counters@, @counter@, @counter_i@: the counter array and the
--   values used to compute the index of the counter to use
--   (@counter_i * num_counters + flat_segment_id `rem` num_counters@).
--
-- * @sync_arr@: an array whose element 0 is used to pass the "was this
--   the last group?" flag from thread 0 to the other threads of the
--   group.
--
-- * @group_res_arrs@: the arrays holding the per-group partial
--   results.
--
-- * @red_arrs@: the arrays (indexed by local thread ID) that
--   'groupReduce' operates on.
reductionStageTwo ::
  KernelConstants ->
  [PatElem GPUMem] ->
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  [Imp.TExp Int64] ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  SegBinOpSlug ->
  [LParam GPUMem] ->
  [LParam GPUMem] ->
  Lambda GPUMem ->
  [SubExp] ->
  Imp.TExp Int32 ->
  VName ->
  Imp.TExp Int32 ->
  VName ->
  [VName] ->
  [VName] ->
  InKernelGen ()
reductionStageTwo
  constants
  segred_pes
  group_id
  flat_segment_id
  segment_gtids
  first_group_for_segment
  groups_per_segment
  slug
  red_x_params
  red_y_params
  red_op_renamed
  nes
  num_counters
  counter
  counter_i
  sync_arr
  group_res_arrs
  red_arrs = do
    let local_tid = kernelLocalThreadId constants
        group_size = kernelGroupSize constants
    -- NOTE(review): 'old_counter' is declared with prim type @int32@
    -- and is the target of an @Int32@ atomic below, but it is compared
    -- against the 64-bit 'groups_per_segment', so its phantom type is
    -- inferred as @Int64@.  Confirm that this mismatch is intended.
    old_counter <- dPrim "old_counter" int32
    (counter_mem, _, counter_offset) <-
      fullyIndexArray
        counter
        [ sExt64 $
            counter_i * num_counters
              + flat_segment_id `rem` num_counters
        ]
    comment "first thread in group saves group result to global memory" $
      sWhen (local_tid .==. 0) $ do
        forM_ (take (length nes) $ zip group_res_arrs (slugAccs slug)) $ \(v, (acc, acc_is)) ->
          copyDWIMFix v [0, sExt64 group_id] (Var acc) acc_is
        -- Fence between storing the partial result and bumping the
        -- counter.
        sOp $ Imp.MemFence Imp.FenceGlobal
        -- Increment the counter, thus stating that our result is
        -- available.
        sOp $
          Imp.Atomic DefaultSpace $
            Imp.AtomicAdd
              Int32
              (tvVar old_counter)
              counter_mem
              counter_offset
              $ untyped (1 :: Imp.TExp Int32)
        -- Now check if we were the last group to write our result.  If
        -- so, it is our responsibility to produce the final result.
        sWrite sync_arr [0] $ untyped $ tvExp old_counter .==. groups_per_segment - 1

    -- Only thread 0 wrote the flag above; the barrier precedes every
    -- thread reading it back from 'sync_arr'.
    sOp $ Imp.Barrier Imp.FenceGlobal

    is_last_group <- dPrim "is_last_group" Bool
    copyDWIMFix (tvVar is_last_group) [] (Var sync_arr) [0]
    sWhen (tvExp is_last_group) $ do
      -- The final group has written its result (and it was
      -- us!), so read in all the group results and perform the
      -- final stage of the reduction.  But first, we reset the
      -- counter so it is ready for next time.  This is done
      -- with an atomic to avoid warnings about write/write
      -- races in oclgrind.
      sWhen (local_tid .==. 0) $
        sOp $
          Imp.Atomic DefaultSpace $
            Imp.AtomicAdd Int32 (tvVar old_counter) counter_mem counter_offset $
              untyped $ negate groups_per_segment

      sLoopNest (slugShape slug) $ \vec_is -> do
        -- There is no guarantee that the number of workgroups for the
        -- segment is less than the workgroup size, so each thread may
        -- have to read multiple elements.  We do this in a sequential
        -- way that may induce non-coalesced accesses, but the total
        -- number of accesses should be tiny here.
        comment "read in the per-group-results" $ do
          read_per_thread <-
            dPrimVE "read_per_thread" $
              groups_per_segment `divUp` sExt64 group_size

          -- Start every thread's running result at the neutral
          -- element.
          forM_ (zip red_x_params nes) $ \(p, ne) ->
            copyDWIMFix (paramName p) [] ne []

          sFor "i" read_per_thread $ \i -> do
            -- Each thread handles a contiguous run of
            -- 'read_per_thread' partial results.
            group_res_id <-
              dPrimVE "group_res_id" $
                sExt64 local_tid * read_per_thread + i
            index_of_group_res <-
              dPrimVE "index_of_group_res" $
                first_group_for_segment + group_res_id

            -- Because of the rounding up in 'read_per_thread', some
            -- indices may lie past the last group of the segment.
            sWhen (group_res_id .<. groups_per_segment) $ do
              forM_ (zip red_y_params group_res_arrs) $
                \(p, group_res_arr) ->
                  copyDWIMFix
                    (paramName p)
                    []
                    (Var group_res_arr)
                    ([0, index_of_group_res] ++ vec_is)

              -- Apply the operator and store its result back in the
              -- x-parameters, which serve as the running result.
              compileStms mempty (bodyStms $ slugBody slug) $
                forM_ (zip red_x_params $ bodyResult $ slugBody slug) $ \(p, se) ->
                  copyDWIMFix (paramName p) [] se []

        -- Only primitive-typed parameters are copied into the
        -- reduction arrays here.
        forM_ (zip red_x_params red_arrs) $ \(p, arr) ->
          when (primType $ paramType p) $
            copyDWIMFix arr [sExt64 local_tid] (Var $ paramName p) []

        sOp $ Imp.Barrier Imp.FenceLocal

        sComment "reduce the per-group results" $ do
          groupReduce (sExt32 group_size) red_op_renamed red_arrs

          sComment "and back to memory with the final result" $
            sWhen (local_tid .==. 0) $
              forM_ (zip segred_pes $ lambdaParams red_op_renamed) $ \(pe, p) ->
                copyDWIMFix
                  (patElemName pe)
                  (segment_gtids ++ vec_is)
                  (Var $ paramName p)
                  []