{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | We generate code for non-segmented/single-segment SegRed using
-- the basic approach outlined in the paper "Design and GPGPU
-- Performance of Futhark’s Redomap Construct" (ARRAY '16).  The main
-- deviations are:
--
-- * While we still use two-phase reduction, we use only a single
--   kernel, with the final workgroup to write a result (tracked via
--   an atomic counter) performing the final reduction as well.
--
-- * Instead of depending on storage layout transformations to handle
--   non-commutative reductions efficiently, we slide a
--   @groupsize@-sized window over the input, and perform a parallel
--   reduction for each window.  This sacrifices the notion of
--   efficient sequentialisation, but is sometimes faster and
--   definitely simpler and more predictable (and uses less auxiliary
--   storage).
--
-- For segmented reductions we use the approach from "Strategies for
-- Regular Segmented Reductions on GPU" (FHPC '17).  This involves
-- having two different strategies, and dynamically deciding which one
-- to use based on the number of segments and segment size. We use the
-- (static) @group_size@ to decide which of the following two
-- strategies to choose:
--
-- * Large: uses one or more groups to process a single segment. If
--   multiple groups are used per segment, the intermediate reduction
--   results must be recursively reduced, until there is only a single
--   value per segment.
--
--   Each thread /can/ read multiple elements, which will greatly
--   increase performance; however, if the reduction is
--   non-commutative we will have to use a less efficient traversal
--   (with interim group-wide reductions) to enable coalesced memory
--   accesses, just as in the non-segmented case.
--
-- * Small: lets each group process /multiple/ segments. We will
--   only use this approach when we can process at least two
--   segments within a single group. In those
--   cases, we would allocate a /whole/ group per segment with the
--   large strategy, but at most 50% of the threads in the group would
--   have any element to read, which becomes highly inefficient.
module Futhark.CodeGen.ImpGen.Kernels.SegRed
  ( compileSegRed,
    compileSegRed',
    DoSegBody,
  )
where

import Control.Monad.Except
import Data.List (genericLength, zip7)
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.Base
import Futhark.Error
import Futhark.IR.KernelsMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Transform.Rename
import Futhark.Util (chunks)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)

-- | The maximum number of operators we support in a single SegRed.
-- This limit arises out of the static allocation of counters: the
-- non-segmented reduction allocates exactly this many 32-bit counters
-- up front.
maxNumOps :: Int32
maxNumOps = 10

-- | Code generator for the per-thread body of a SegRed.
--
-- It is given a continuation to which it must pass the body's
-- reduction results.  Each result is a t'SubExp' together with a
-- (possibly empty) list of indexes to apply to that 'SubExp' when
-- reading the value.
type DoSegBody =
  ([(SubExp, [Imp.TExp Int64])] -> InKernelGen ()) ->
  InKernelGen ()

-- | Compile 'SegRed' instance to host-level code with calls to
-- various kernels.
--
-- The statements of the kernel body are compiled inside the kernel.
-- The first 'segBinOpResults' results of the body are reduction
-- results and are handed to the reduction continuation, each paired
-- with an empty index list.  The remaining results are map-out results
-- and are written to the corresponding trailing pattern elements.
compileSegRed ::
  Pattern KernelsMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp KernelsMem] ->
  KernelBody KernelsMem ->
  CallKernelGen ()
compileSegRed pat lvl space reds body =
  compileSegRed' pat lvl space reds $ \red_cont ->
    compileStms mempty (kernelBodyStms body) $ do
      let (red_res, map_res) =
            splitAt (segBinOpResults reds) $ kernelBodyResult body

      sComment "save map-out results" $ do
        -- The pattern elements after the reduction results receive the
        -- map-out results, in order.
        let map_arrs = drop (segBinOpResults reds) $ patternElements pat
        zipWithM_ (compileThreadResult space) map_arrs map_res

      red_cont $ zip (map kernelResultSubExp red_res) $ repeat []

-- | Like 'compileSegRed', but where the body is a monadic action.
--
-- Dispatch:
--
-- * More than 'maxNumOps' operators is reported as a compiler
--   limitation.
--
-- * A two-dimensional segment space whose outer dimension is the
--   constant 1 is compiled as a non-segmented reduction.
--
-- * Otherwise a run-time test selects the small-segments strategy when
--   twice the segment size is less than the group size, and the
--   large-segments strategy otherwise.
compileSegRed' ::
  Pattern KernelsMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp KernelsMem] ->
  DoSegBody ->
  CallKernelGen ()
compileSegRed' pat lvl space reds body
  | genericLength reds > maxNumOps =
    compilerLimitationS $
      "compileSegRed': at most " ++ show maxNumOps ++ " reduction operators are supported."
  | [(_, Constant (IntValue (Int64Value 1))), _] <- unSegSpace space =
    nonsegmentedReduction pat num_groups group_size space reds body
  | otherwise = do
    -- Compare in 64-bit arithmetic.  The segment size is a 64-bit
    -- dimension.  Truncating it to 32 bits, or doubling it in 32 bits,
    -- can wrap around for very large segments and wrongly select the
    -- small-segments strategy.
    let group_size' = toInt64Exp $ unCount group_size
        segment_size = toInt64Exp $ last $ segSpaceDims space
        use_small_segments = segment_size * 2 .<. group_size'
    sIf
      use_small_segments
      (smallSegmentsReduction pat num_groups group_size space reds body)
      (largeSegmentsReduction pat num_groups group_size space reds body)
  where
    num_groups = segNumGroups lvl
    group_size = segGroupSize lvl

-- | Prepare intermediate arrays for the reduction.  Prim-typed
-- arguments go in local memory (so we need to do the allocation of
-- those arrays inside the kernel), while array-typed arguments go in
-- global memory.  Allocations for the former have already been
-- performed.  This policy is baked into how the allocations are done
-- in ExplicitAllocations.
--
-- One array is returned per accumulator parameter of the operator
-- (the first @length nes@ lambda parameters):
--
-- * A parameter already bound to memory (@ArrayIn mem _@) gets an array
--   of shape @[num_threads] ++ shape@ over that same memory, with an
--   iota index function.
--
-- * Any other parameter gets a fresh local-memory array of
--   @group_size@ elements of its element type.
intermediateArrays ::
  Count GroupSize SubExp ->
  SubExp ->
  SegBinOp KernelsMem ->
  InKernelGen [VName]
intermediateArrays (Count group_size) num_threads (SegBinOp _ red_op nes _) = do
  let red_acc_params = take (length nes) $ lambdaParams red_op
  forM red_acc_params $ \p ->
    case paramDec p of
      MemArray pt shape _ (ArrayIn mem _) -> do
        let shape' = Shape [num_threads] <> shape
        sArray "red_arr" pt shape' $
          ArrayIn mem $ IxFun.iota $ map pe64 $ shapeDims shape'
      _ -> do
        let pt = elemType $ paramType p
            shape = Shape [group_size]
        sAllocArray "red_arr" pt shape $ Space "local"

-- | Arrays for storing group results.
--
-- The group-result arrays have an extra dimension (of size groupsize)
-- because they are also used for keeping vectorised accumulators for
-- first-stage reduction, if necessary.  When actually storing group
-- results, the first index is set to 0.
--
-- One device-memory array is allocated for each return value of each
-- operator.  Its logical shape is
-- @[group_size][num_groups]@ followed by the operator's shape and then
-- the element's own array shape.  It is stored with the group-size
-- dimension moved innermost.
groupResultArrays ::
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  [SegBinOp KernelsMem] ->
  CallKernelGen [[VName]]
groupResultArrays (Count virt_num_groups) (Count group_size) reds =
  forM reds $ \(SegBinOp _ lam _ shape) ->
    forM (lambdaReturnType lam) $ \t -> do
      let pt = elemType t
          full_shape =
            Shape [group_size, virt_num_groups] <> shape <> arrayShape t
          -- Move the groupsize dimension last to ensure coalesced
          -- memory access.
          perm = [1 .. shapeRank full_shape - 1] ++ [0]
      sAllocArrayPerm "group_res_arr" pt full_shape (Space "device") perm

-- | Generate a single kernel (@segred_nonseg@) performing a two-stage
-- reduction over a segment space whose outer dimension is 1.
--
-- On the host this does the following:
--
-- * allocates a statically initialised device array of 'maxNumOps'
--   32-bit zero counters;
--
-- * allocates the group-result arrays for every operator;
--
-- * computes @num_threads = num_groups * group_size@.
--
-- Inside the kernel, stage one runs for all operators together.  Stage
-- two then runs once per operator.  Each operator's stage two receives
-- the counter array and the operator's position @i@; presumably @i@
-- selects the operator's counter (see 'reductionStageTwo').
nonsegmentedReduction ::
  Pattern KernelsMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegBinOp KernelsMem] ->
  DoSegBody ->
  CallKernelGen ()
nonsegmentedReduction segred_pat num_groups group_size space reds body = do
  let (gtids, dims) = unzip $ unSegSpace space
      dims' = map toInt64Exp dims
      num_groups' = fmap toInt64Exp num_groups
      group_size' = fmap toInt64Exp group_size
      global_tid = Imp.vi32 $ segFlat space
      w = last dims'

  counter <-
    sStaticArray "counter" (Space "device") int32 $
      Imp.ArrayValues $
        replicate (fromIntegral maxNumOps) $ IntValue $ Int32Value 0

  reds_group_res_arrs <- groupResultArrays num_groups group_size reds

  num_threads <-
    dPrimV "num_threads" $
      unCount num_groups' * unCount group_size'

  emit $ Imp.DebugPrint "\n# SegRed" Nothing

  sKernelThread "segred_nonseg" num_groups' group_size' (segFlat space) $ do
    constants <- kernelConstants <$> askEnv
    sync_arr <-
      sAllocArray "sync_arr" Bool (Shape [intConst Int32 1]) $ Space "local"
    reds_arrs <- mapM (intermediateArrays group_size (tvSize num_threads)) reds

    -- Since this is the nonsegmented case, all outer segment IDs must
    -- necessarily be 0.  The gtids index the 64-bit dimensions in
    -- dims' (they are paired with them for reductionStageOne below), so
    -- declare them as 64-bit.  A 32-bit declaration could truncate the
    -- innermost index for inputs with 2^31 or more elements.
    forM_ gtids $ \v -> dPrimV_ v (0 :: Imp.TExp Int64)

    let num_elements = Imp.elements w
        elems_per_thread =
          num_elements
            `divUp` Imp.elements (sExt64 (kernelNumThreads constants))

    slugs <-
      mapM
        ( segBinOpSlug
            (kernelLocalThreadId constants)
            (kernelGroupId constants)
        )
        $ zip3 reds reds_arrs reds_group_res_arrs
    reds_op_renamed <-
      reductionStageOne
        constants
        (zip gtids dims')
        num_elements
        global_tid
        elems_per_thread
        (tvVar num_threads)
        slugs
        body

    -- Split the pattern into one chunk per operator, sized by the
    -- operator's number of neutral elements.
    let segred_pes =
          chunks (map (length . segBinOpNeutral) reds) $
            patternElements segred_pat
    forM_
      ( zip7
          reds
          reds_arrs
          reds_group_res_arrs
          segred_pes
          slugs
          reds_op_renamed
          [0 ..]
      )
      $ \( SegBinOp _ red_op nes _,
           red_arrs,
           group_res_arrs,
           pes,
           slug,
           red_op_renamed,
           i
           ) -> do
          let (red_x_params, red_y_params) =
                splitAt (length nes) $ lambdaParams red_op
          reductionStageTwo
            constants
            pes
            (kernelGroupId constants)
            0
            [0]
            0
            (sExt64 $ kernelNumGroups constants)
            slug
            red_x_params
            red_y_params
            red_op_renamed
            nes
            1
            counter
            (fromInteger i)
            sync_arr
            group_res_arrs
            red_arrs

-- | Generate code for a segmented reduction using the \"small
-- segments\" strategy, in which each workgroup reduces several whole
-- segments at once.
--
-- Within a (virtual) group, thread @ltid@ loads element
-- @ltid \`rem\` segment_size@ of segment
-- @group_id * segments_per_group + ltid \`quot\` segment_size@.
-- Threads outside every segment of the group write the neutral
-- elements instead.
--
-- The per-thread values are written to the intermediate arrays. They
-- are then combined with a segmented group-wide scan, where segment
-- boundaries are detected by @crossesSegment@. Finally, thread
-- @i < segments_per_group@ copies the last scanned element of the
-- @i@th segment of its group into the result pattern.
--
-- Since fewer physical groups than @required_groups@ may be launched,
-- the per-group work is wrapped in 'virtualiseGroups'.
--
-- NOTE(review): @segments_per_group = group_size \`quot\`
-- segment_size@, so this presumably relies on the caller only picking
-- this strategy when @segment_size <= group_size@. Otherwise the
-- @divUp@ below divides by zero. Confirm at the call site. It also
-- assumes the segment space has at least one dimension (@last@ and
-- @init@ are used on it).
smallSegmentsReduction ::
  Pattern KernelsMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegBinOp KernelsMem] ->
  DoSegBody ->
  CallKernelGen ()
smallSegmentsReduction (Pattern _ segred_pes) num_groups group_size space reds body = do
  let (gtids, dims) = unzip $ unSegSpace space
      dims' = map toInt64Exp dims
      segment_size = last dims'

  -- Careful to avoid division by zero: every division or remainder by
  -- the segment size below uses this clamped value.
  --
  -- When the real segment size is zero, the map body and the scan are
  -- guarded off. Every thread then stores the neutral elements, and
  -- those become the segment results, so the clamp never changes a
  -- meaningful value.
  segment_size_nonzero <-
    dPrimVE "segment_size_nonzero" $ sMax64 1 segment_size

  let num_groups' = fmap toInt64Exp num_groups
      group_size' = fmap toInt64Exp group_size
  num_threads <- dPrimV "num_threads" $ unCount num_groups' * unCount group_size'
  let num_segments = product $ init dims'
      segments_per_group = unCount group_size' `quot` segment_size_nonzero
      required_groups = sExt32 $ num_segments `divUp` segments_per_group

  emit $ Imp.DebugPrint "\n# SegRed-small" Nothing
  emit $ Imp.DebugPrint "num_segments" $ Just $ untyped num_segments
  emit $ Imp.DebugPrint "segment_size" $ Just $ untyped segment_size
  emit $ Imp.DebugPrint "segments_per_group" $ Just $ untyped segments_per_group
  emit $ Imp.DebugPrint "required_groups" $ Just $ untyped required_groups

  sKernelThread "segred_small" num_groups' group_size' (segFlat space) $ do
    constants <- kernelConstants <$> askEnv
    reds_arrs <- mapM (intermediateArrays group_size (Var $ tvVar num_threads)) reds

    -- We probably do not have enough actual workgroups to cover the
    -- entire iteration space.  Some groups thus have to perform double
    -- duty; we put an outer loop to accomplish this.
    virtualiseGroups SegVirt required_groups $ \group_id' -> do
      -- Compute the 'n' input indices.  The outer 'n-1' correspond to
      -- the segment ID, and are computed from the group id.  The inner
      -- is computed from the local thread id, and may be out-of-bounds.
      let ltid = sExt64 $ kernelLocalThreadId constants
          group_first_segment = sExt64 group_id' * segments_per_group
          segment_index =
            (ltid `quot` segment_size_nonzero) + group_first_segment
          -- Use the clamped size: with a zero segment size the raw
          -- remainder would be a division by zero in the kernel.
          index_within_segment = ltid `rem` segment_size_nonzero

      zipWithM_ dPrimV_ (init gtids) $ unflattenIndex (init dims') segment_index
      dPrimV_ (last gtids) index_within_segment

      let out_of_bounds =
            forM_ (zip reds reds_arrs) $ \(SegBinOp _ _ nes _, red_arrs) ->
              forM_ (zip red_arrs nes) $ \(arr, ne) ->
                copyDWIMFix arr [ltid] ne []

          in_bounds =
            body $ \red_res ->
              sComment "save results to be reduced" $ do
                let red_dests = zip (concat reds_arrs) $ repeat [ltid]
                forM_ (zip red_dests red_res) $ \((d, d_is), (res, res_is)) ->
                  copyDWIMFix d d_is res res_is

      sComment "apply map function if in bounds" $
        sIf
          ( segment_size .>. 0
              .&&. isActive (init $ zip gtids dims)
              .&&. ltid .<. segment_size * segments_per_group
          )
          in_bounds
          out_of_bounds

      sOp $ Imp.ErrorSync Imp.FenceLocal -- Also implicitly barrier.

      -- Positions 'from' and 'to' lie in different segments when the
      -- distance between them exceeds the offset of 'to' within its
      -- segment.  Only evaluated under the 'segment_size > 0' guard,
      -- where the clamped size equals the real one.
      let crossesSegment from to =
            (sExt64 to - sExt64 from) .>. (sExt64 to `rem` segment_size_nonzero)
      sWhen (segment_size .>. 0) $
        sComment "perform segmented scan to imitate reduction" $
          forM_ (zip reds reds_arrs) $ \(SegBinOp _ red_op _ _, red_arrs) ->
            groupScan
              (Just crossesSegment)
              (tvExp num_threads)
              (segment_size * segments_per_group)
              red_op
              red_arrs

      sOp $ Imp.Barrier Imp.FenceLocal

      -- Figure out which segment result this thread should write: thread
      -- 'ltid' is responsible for the ltid'th segment of this group.
      let flat_segment_index = group_first_segment + ltid
      sComment "save final values of segments" $
        sWhen
          ( flat_segment_index .<. num_segments
              .&&. ltid .<. segments_per_group
          )
          $ forM_ (zip segred_pes (concat reds_arrs)) $ \(pe, arr) ->
            -- After the segmented scan, the last element of each segment
            -- holds that segment's reduction result.
            copyDWIMFix
              (patElemName pe)
              (unflattenIndex (init dims') flat_segment_index)
              (Var arr)
              [(ltid + 1) * segment_size_nonzero - 1]

      -- Finally another barrier, because we will be writing to the
      -- local memory array first thing in the next iteration.
      sOp $ Imp.Barrier Imp.FenceLocal

largeSegmentsReduction ::
  Pattern KernelsMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegBinOp KernelsMem] ->
  DoSegBody ->
  CallKernelGen ()
largeSegmentsReduction :: Pattern KernelsMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp KernelsMem]
-> DoSegBody
-> CallKernelGen ()
largeSegmentsReduction Pattern KernelsMem
segred_pat Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegBinOp KernelsMem]
reds DoSegBody
body = do
  let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      dims' :: [TExp Int64]
dims' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
dims
      num_segments :: TExp Int64
num_segments = [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a]
init [TExp Int64]
dims'
      segment_size :: TExp Int64
segment_size = [TExp Int64] -> TExp Int64
forall a. [a] -> a
last [TExp Int64]
dims'
      num_groups' :: Count NumGroups (TExp Int64)
num_groups' = (SubExp -> TExp Int64)
-> Count NumGroups SubExp -> Count NumGroups (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count NumGroups SubExp
num_groups
      group_size' :: Count GroupSize (TExp Int64)
group_size' = (SubExp -> TExp Int64)
-> Count GroupSize SubExp -> Count GroupSize (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count GroupSize SubExp
group_size

  (TExp Int64
groups_per_segment, Count Elements (TExp Int64)
elems_per_thread) <-
    TExp Int64
-> TExp Int64
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> CallKernelGen (TExp Int64, Count Elements (TExp Int64))
groupsPerSegmentAndElementsPerThread
      TExp Int64
segment_size
      TExp Int64
num_segments
      Count NumGroups (TExp Int64)
num_groups'
      Count GroupSize (TExp Int64)
group_size'
  TV Int64
virt_num_groups <-
    String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"virt_num_groups" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
      TExp Int64
groups_per_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
num_segments

  TV Int64
num_threads <-
    String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"num_threads" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
      Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size'

  TV Int64
threads_per_segment <-
    String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"threads_per_segment" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
      TExp Int64
groups_per_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size'

  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"\n# SegRed-large" Maybe Exp
forall a. Maybe a
Nothing
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"num_segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
num_segments
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"segment_size" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
segment_size
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"virt_num_groups" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
virt_num_groups
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"num_groups" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count NumGroups (TExp Int64)
num_groups'
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"group_size" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$ Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count GroupSize (TExp Int64)
group_size'
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"elems_per_thread" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$ Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
elems_per_thread
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"groups_per_segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
groups_per_segment

  [[VName]]
reds_group_res_arrs <- Count NumGroups SubExp
-> Count GroupSize SubExp
-> [SegBinOp KernelsMem]
-> CallKernelGen [[VName]]
groupResultArrays (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Count (TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
virt_num_groups)) Count GroupSize SubExp
group_size [SegBinOp KernelsMem]
reds

  -- In principle we should have a counter for every segment.  Since
  -- the number of segments is a dynamic quantity, we would have to
  -- allocate and zero out an array here, which is expensive.
  -- However, we exploit the fact that the number of segments being
  -- reduced at any point in time is limited by the number of
  -- workgroups. If we bound the number of workgroups, we can get away
  -- with using that many counters.  FIXME: Is this limit checked
  -- anywhere?  There are other places in the compiler that will fail
  -- if the group count exceeds the maximum group size, which is at
  -- most 1024 anyway.
  let num_counters :: Int
num_counters = Int32 -> Int
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int32
maxNumOps Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
1024
  VName
counter <-
    String
-> Space
-> PrimType
-> ArrayContents
-> ImpM KernelsMem HostEnv HostOp VName
forall lore r op.
String
-> Space -> PrimType -> ArrayContents -> ImpM lore r op VName
sStaticArray String
"counter" (String -> Space
Space String
"device") PrimType
int32 (ArrayContents -> ImpM KernelsMem HostEnv HostOp VName)
-> ArrayContents -> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
      Int -> ArrayContents
Imp.ArrayZeros Int
num_counters

  String
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"segred_large" Count NumGroups (TExp Int64)
num_groups' Count GroupSize (TExp Int64)
group_size' (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
    [[VName]]
reds_arrs <- (SegBinOp KernelsMem -> InKernelGen [VName])
-> [SegBinOp KernelsMem]
-> ImpM KernelsMem KernelEnv KernelOp [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Count GroupSize SubExp
-> SubExp -> SegBinOp KernelsMem -> InKernelGen [VName]
intermediateArrays Count GroupSize SubExp
group_size (TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
num_threads)) [SegBinOp KernelsMem]
reds
    VName
sync_arr <- String
-> PrimType
-> Shape
-> Space
-> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op.
String -> PrimType -> Shape -> Space -> ImpM lore r op VName
sAllocArray String
"sync_arr" PrimType
Bool ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
1]) (Space -> ImpM KernelsMem KernelEnv KernelOp VName)
-> Space -> ImpM KernelsMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ String -> Space
Space String
"local"

    -- We probably do not have enough actual workgroups to cover the
    -- entire iteration space.  Some groups thus have to perform double
    -- duty; we put an outer loop to accomplish this.
    SegVirt
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt (TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
virt_num_groups)) ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
group_id -> do
      let segment_gtids :: [VName]
segment_gtids = [VName] -> [VName]
forall a. [a] -> [a]
init [VName]
gtids
          w :: SubExp
w = [SubExp] -> SubExp
forall a. [a] -> a
last [SubExp]
dims
          local_tid :: TExp Int32
local_tid = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants

      TExp Int32
flat_segment_id <-
        String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"flat_segment_id" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
          TExp Int32
group_id TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
groups_per_segment

      TExp Int64
global_tid <-
        String
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"global_tid" (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
          (TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
group_id TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size') TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
local_tid)
            TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`rem` (TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size') TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
groups_per_segment)

      let first_group_for_segment :: TExp Int64
first_group_for_segment = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
flat_segment_id TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
groups_per_segment

      (VName -> TExp Int64 -> InKernelGen ())
-> [VName] -> [TExp Int64] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> InKernelGen ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ [VName]
segment_gtids ([TExp Int64] -> InKernelGen ()) -> [TExp Int64] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ([TExp Int64] -> [TExp Int64]
forall a. [a] -> [a]
init [TExp Int64]
dims') (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
flat_segment_id
      VName -> PrimType -> InKernelGen ()
forall lore r op. VName -> PrimType -> ImpM lore r op ()
dPrim_ ([VName] -> VName
forall a. [a] -> a
last [VName]
gtids) PrimType
int64
      let num_elements :: Count Elements (TExp Int64)
num_elements = TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
Imp.elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
w

      [SegBinOpSlug]
slugs <-
        ((SegBinOp KernelsMem, [VName], [VName])
 -> ImpM KernelsMem KernelEnv KernelOp SegBinOpSlug)
-> [(SegBinOp KernelsMem, [VName], [VName])]
-> ImpM KernelsMem KernelEnv KernelOp [SegBinOpSlug]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (TExp Int32
-> TExp Int32
-> (SegBinOp KernelsMem, [VName], [VName])
-> ImpM KernelsMem KernelEnv KernelOp SegBinOpSlug
segBinOpSlug TExp Int32
local_tid TExp Int32
group_id) ([(SegBinOp KernelsMem, [VName], [VName])]
 -> ImpM KernelsMem KernelEnv KernelOp [SegBinOpSlug])
-> [(SegBinOp KernelsMem, [VName], [VName])]
-> ImpM KernelsMem KernelEnv KernelOp [SegBinOpSlug]
forall a b. (a -> b) -> a -> b
$
          [SegBinOp KernelsMem]
-> [[VName]]
-> [[VName]]
-> [(SegBinOp KernelsMem, [VName], [VName])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegBinOp KernelsMem]
reds [[VName]]
reds_arrs [[VName]]
reds_group_res_arrs
      [Lambda KernelsMem]
reds_op_renamed <-
        KernelConstants
-> [(VName, TExp Int64)]
-> Count Elements (TExp Int64)
-> TExp Int32
-> Count Elements (TExp Int64)
-> VName
-> [SegBinOpSlug]
-> DoSegBody
-> InKernelGen [Lambda KernelsMem]
reductionStageOne
          KernelConstants
constants
          ([VName] -> [TExp Int64] -> [(VName, TExp Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [TExp Int64]
dims')
          Count Elements (TExp Int64)
num_elements
          (TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
global_tid)
          Count Elements (TExp Int64)
elems_per_thread
          (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
threads_per_segment)
          [SegBinOpSlug]
slugs
          DoSegBody
body

      let segred_pes :: [[PatElemT LetDecMem]]
segred_pes =
            [Int] -> [PatElemT LetDecMem] -> [[PatElemT LetDecMem]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOp KernelsMem -> Int) -> [SegBinOp KernelsMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp KernelsMem -> [SubExp]) -> SegBinOp KernelsMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp KernelsMem -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral) [SegBinOp KernelsMem]
reds) ([PatElemT LetDecMem] -> [[PatElemT LetDecMem]])
-> [PatElemT LetDecMem] -> [[PatElemT LetDecMem]]
forall a b. (a -> b) -> a -> b
$
              PatternT LetDecMem -> [PatElemT LetDecMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern KernelsMem
PatternT LetDecMem
segred_pat

          multiple_groups_per_segment :: InKernelGen ()
multiple_groups_per_segment =
            [(SegBinOp KernelsMem, [VName], [VName], [PatElemT LetDecMem],
  SegBinOpSlug, Lambda KernelsMem, Integer)]
-> ((SegBinOp KernelsMem, [VName], [VName], [PatElemT LetDecMem],
     SegBinOpSlug, Lambda KernelsMem, Integer)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_
              ( [SegBinOp KernelsMem]
-> [[VName]]
-> [[VName]]
-> [[PatElemT LetDecMem]]
-> [SegBinOpSlug]
-> [Lambda KernelsMem]
-> [Integer]
-> [(SegBinOp KernelsMem, [VName], [VName], [PatElemT LetDecMem],
     SegBinOpSlug, Lambda KernelsMem, Integer)]
forall a b c d e f g.
[a]
-> [b]
-> [c]
-> [d]
-> [e]
-> [f]
-> [g]
-> [(a, b, c, d, e, f, g)]
zip7
                  [SegBinOp KernelsMem]
reds
                  [[VName]]
reds_arrs
                  [[VName]]
reds_group_res_arrs
                  [[PatElemT LetDecMem]]
segred_pes
                  [SegBinOpSlug]
slugs
                  [Lambda KernelsMem]
reds_op_renamed
                  [Integer
0 ..]
              )
              (((SegBinOp KernelsMem, [VName], [VName], [PatElemT LetDecMem],
   SegBinOpSlug, Lambda KernelsMem, Integer)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((SegBinOp KernelsMem, [VName], [VName], [PatElemT LetDecMem],
     SegBinOpSlug, Lambda KernelsMem, Integer)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \( SegBinOp Commutativity
_ Lambda KernelsMem
red_op [SubExp]
nes Shape
_,
                   [VName]
red_arrs,
                   [VName]
group_res_arrs,
                   [PatElemT LetDecMem]
pes,
                   SegBinOpSlug
slug,
                   Lambda KernelsMem
red_op_renamed,
                   Integer
i
                   ) -> do
                  let ([Param LetDecMem]
red_x_params, [Param LetDecMem]
red_y_params) =
                        Int -> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem]))
-> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
red_op
                  KernelConstants
-> [PatElem KernelsMem]
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> TExp Int64
-> TExp Int64
-> SegBinOpSlug
-> [LParam KernelsMem]
-> [LParam KernelsMem]
-> Lambda KernelsMem
-> [SubExp]
-> TExp Int32
-> VName
-> TExp Int32
-> VName
-> [VName]
-> [VName]
-> InKernelGen ()
reductionStageTwo
                    KernelConstants
constants
                    [PatElem KernelsMem]
[PatElemT LetDecMem]
pes
                    TExp Int32
group_id
                    TExp Int32
flat_segment_id
                    ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
segment_gtids)
                    (TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
first_group_for_segment)
                    TExp Int64
groups_per_segment
                    SegBinOpSlug
slug
                    [LParam KernelsMem]
[Param LetDecMem]
red_x_params
                    [LParam KernelsMem]
[Param LetDecMem]
red_y_params
                    Lambda KernelsMem
red_op_renamed
                    [SubExp]
nes
                    (Int -> TExp Int32
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
num_counters)
                    VName
counter
                    (Integer -> TExp Int32
forall a. Num a => Integer -> a
fromInteger Integer
i)
                    VName
sync_arr
                    [VName]
group_res_arrs
                    [VName]
red_arrs

          one_group_per_segment :: InKernelGen ()
one_group_per_segment =
            String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
comment String
"first thread in group saves final result to memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              [(SegBinOpSlug, [PatElemT LetDecMem])]
-> ((SegBinOpSlug, [PatElemT LetDecMem]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOpSlug]
-> [[PatElemT LetDecMem]] -> [(SegBinOpSlug, [PatElemT LetDecMem])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOpSlug]
slugs [[PatElemT LetDecMem]]
segred_pes) (((SegBinOpSlug, [PatElemT LetDecMem]) -> InKernelGen ())
 -> InKernelGen ())
-> ((SegBinOpSlug, [PatElemT LetDecMem]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOpSlug
slug, [PatElemT LetDecMem]
pes) ->
                TPrimExp Bool ExpLeaf -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                  [(PatElemT LetDecMem, (VName, [TExp Int64]))]
-> ((PatElemT LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem]
-> [(VName, [TExp Int64])]
-> [(PatElemT LetDecMem, (VName, [TExp Int64]))]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LetDecMem]
pes (SegBinOpSlug -> [(VName, [TExp Int64])]
slugAccs SegBinOpSlug
slug)) (((PatElemT LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
 -> InKernelGen ())
-> ((PatElemT LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
v, (VName
acc, [TExp Int64]
acc_is)) ->
                    VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
v) ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
segment_gtids) (VName -> SubExp
Var VName
acc) [TExp Int64]
acc_is

      TPrimExp Bool ExpLeaf
-> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (TExp Int64
groups_per_segment TExp Int64 -> TExp Int64 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
1) InKernelGen ()
one_group_per_segment InKernelGen ()
multiple_groups_per_segment

-- | Decide how many workgroups each segment gets, and how many elements
-- each thread must process, given a segment size, the number of
-- segments, a hint for the total number of groups, and the group size.
--
-- Careful to avoid division by zero here.  The number of segments is
-- clamped to at least one before dividing the group-count hint by it,
-- and the resulting groups-per-segment is itself clamped to at least
-- one, so we really do have at least one group per segment even when
-- the hint is zero.  This keeps the second division (and any caller
-- that divides by the groups-per-segment count) well-defined.
-- NOTE(review): assumes the group size is positive; it is not clamped.
groupsPerSegmentAndElementsPerThread ::
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Count NumGroups (Imp.TExp Int64) ->
  Count GroupSize (Imp.TExp Int64) ->
  CallKernelGen
    ( Imp.TExp Int64,
      Imp.Count Imp.Elements (Imp.TExp Int64)
    )
groupsPerSegmentAndElementsPerThread segment_size num_segments num_groups_hint group_size = do
  groups_per_segment <-
    dPrimVE "groups_per_segment" $
      -- Outer clamp: a zero hint would otherwise yield zero groups per
      -- segment and a division by zero below.
      sMax64 1 $ unCount num_groups_hint `divUp` sMax64 1 num_segments
  elements_per_thread <-
    dPrimVE "elements_per_thread" $
      segment_size `divUp` (unCount group_size * groups_per_segment)
  return (groups_per_segment, Imp.elements elements_per_thread)

-- | A 'SegBinOp' together with the auxiliary storage used while
-- generating code for its reduction.
data SegBinOpSlug = SegBinOpSlug
  { -- | The reduction operator this slug describes.
    slugOp :: SegBinOp KernelsMem,
    -- | The arrays used for computing the intra-group reduction
    -- (either local or global memory).
    slugArrs :: [VName],
    -- | Places to store accumulator in stage 1 reduction.  Each entry
    -- is a variable or array name paired with the index prefix at
    -- which the accumulator is stored (an empty prefix denotes a plain
    -- scalar variable).
    slugAccs :: [(VName, [Imp.TExp Int64])]
  }

-- | The body of the slug's reduction lambda.
slugBody :: SegBinOpSlug -> Body KernelsMem
slugBody slug = lambdaBody (segBinOpLambda (slugOp slug))

-- | Every parameter of the slug's reduction lambda.  The leading ones
-- (as many as there are neutral elements) are the accumulator
-- parameters; the rest receive the next element.
slugParams :: SegBinOpSlug -> [LParam KernelsMem]
slugParams slug = lambdaParams $ segBinOpLambda $ slugOp slug

-- | The neutral elements of the slug's reduction operator.
slugNeutral :: SegBinOpSlug -> [SubExp]
slugNeutral slug = segBinOpNeutral (slugOp slug)

-- | The shape attached to the slug's operator; accumulator
-- initialisation and the per-element reduction loop over it.
slugShape :: SegBinOpSlug -> Shape
slugShape slug = segBinOpShape (slugOp slug)

-- | The combined commutativity of several slugs.  It is obtained by
-- folding each operator's 'Commutativity' together with that type's
-- 'Monoid' instance; an empty list yields 'mempty'.
slugsComm :: [SegBinOpSlug] -> Commutativity
slugsComm = foldMap (segBinOpComm . slugOp)

-- | Split the slug's lambda parameters at the number of neutral
-- elements: 'accParams' yields the leading (accumulator) parameters
-- and 'nextParams' the remaining (next-element) parameters.
accParams, nextParams :: SegBinOpSlug -> [LParam KernelsMem]
accParams slug = fst $ splitAt (length (slugNeutral slug)) (slugParams slug)
nextParams slug = snd $ splitAt (length (slugNeutral slug)) (slugParams slug)

-- | Build a 'SegBinOpSlug' from an operator and two lists of arrays.
--
-- The second tuple component is stored as-is in 'slugArrs'.  For every
-- lambda parameter paired (by 'zipWithM', so only up to the length of
-- the shorter list) with an array from the third component, an
-- accumulator location is chosen:
--
-- * a parameter of primitive type, with an operator shape of rank
--   zero, gets a freshly declared scalar variable named
--   @<param>_acc@, with an empty index prefix;
--
-- * any other parameter uses its paired array, indexed by the
--   (sign-extended) local thread id and group id.
--
-- NOTE(review): the third component presumably holds one array per
-- accumulator parameter; confirm against callers.
segBinOpSlug ::
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  (SegBinOp KernelsMem, [VName], [VName]) ->
  InKernelGen SegBinOpSlug
segBinOpSlug local_tid group_id (op, red_arrs, acc_arrs) = do
  accs <- zipWithM mkAcc (lambdaParams $ segBinOpLambda op) acc_arrs
  pure SegBinOpSlug {slugOp = op, slugArrs = red_arrs, slugAccs = accs}
  where
    -- True when the operator works on scalars (no extra shape).
    scalar_op = shapeRank (segBinOpShape op) == 0

    mkAcc p arr = case paramType p of
      Prim t | scalar_op -> do
        acc <- dPrim (baseString (paramName p) <> "_acc") t
        pure (tvVar acc, [])
      _ ->
        pure (arr, [sExt64 local_tid, sExt64 group_id])

reductionStageZero ::
  KernelConstants ->
  [(VName, Imp.TExp Int64)] ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  Imp.TExp Int32 ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  VName ->
  [SegBinOpSlug] ->
  DoSegBody ->
  InKernelGen ([Lambda KernelsMem], InKernelGen ())
reductionStageZero :: KernelConstants
-> [(VName, TExp Int64)]
-> Count Elements (TExp Int64)
-> TExp Int32
-> Count Elements (TExp Int64)
-> VName
-> [SegBinOpSlug]
-> DoSegBody
-> InKernelGen ([Lambda KernelsMem], InKernelGen ())
reductionStageZero KernelConstants
constants [(VName, TExp Int64)]
ispace Count Elements (TExp Int64)
num_elements TExp Int32
global_tid Count Elements (TExp Int64)
elems_per_thread VName
threads_per_segment [SegBinOpSlug]
slugs DoSegBody
body = do
  let ([VName]
gtids, [TExp Int64]
_dims) = [(VName, TExp Int64)] -> ([VName], [TExp Int64])
forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, TExp Int64)]
ispace
      gtid :: TV Int64
gtid = VName -> PrimType -> TV Int64
forall t. VName -> PrimType -> TV t
mkTV ([VName] -> VName
forall a. [a] -> a
last [VName]
gtids) PrimType
int64
      local_tid :: TExp Int64
local_tid = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants

  -- Figure out how many elements this thread should process.
  TV Int64
chunk_size <- String -> PrimType -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"chunk_size" PrimType
int64
  let ordering :: SplitOrdering
ordering = case [SegBinOpSlug] -> Commutativity
slugsComm [SegBinOpSlug]
slugs of
        Commutativity
Commutative -> SubExp -> SplitOrdering
SplitStrided (SubExp -> SplitOrdering) -> SubExp -> SplitOrdering
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
threads_per_segment
        Commutativity
Noncommutative -> SplitOrdering
SplitContiguous
  SplitOrdering
-> TExp Int64
-> Count Elements (TExp Int64)
-> Count Elements (TExp Int64)
-> TV Int64
-> InKernelGen ()
forall lore r op.
SplitOrdering
-> TExp Int64
-> Count Elements (TExp Int64)
-> Count Elements (TExp Int64)
-> TV Int64
-> ImpM lore r op ()
computeThreadChunkSize SplitOrdering
ordering (TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
global_tid) Count Elements (TExp Int64)
elems_per_thread Count Elements (TExp Int64)
num_elements TV Int64
chunk_size

  Maybe (Exp KernelsMem) -> Scope KernelsMem -> InKernelGen ()
forall lore r op.
Mem lore =>
Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
dScope Maybe (Exp KernelsMem)
forall a. Maybe a
Nothing (Scope KernelsMem -> InKernelGen ())
-> Scope KernelsMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param LetDecMem] -> Scope KernelsMem
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams ([Param LetDecMem] -> Scope KernelsMem)
-> [Param LetDecMem] -> Scope KernelsMem
forall a b. (a -> b) -> a -> b
$ (SegBinOpSlug -> [Param LetDecMem])
-> [SegBinOpSlug] -> [Param LetDecMem]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap SegBinOpSlug -> [LParam KernelsMem]
SegBinOpSlug -> [Param LetDecMem]
slugParams [SegBinOpSlug]
slugs

  String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"neutral-initialise the accumulators" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    [SegBinOpSlug]
-> (SegBinOpSlug -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [SegBinOpSlug]
slugs ((SegBinOpSlug -> InKernelGen ()) -> InKernelGen ())
-> (SegBinOpSlug -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegBinOpSlug
slug ->
      [((VName, [TExp Int64]), SubExp)]
-> (((VName, [TExp Int64]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [TExp Int64])]
-> [SubExp] -> [((VName, [TExp Int64]), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [(VName, [TExp Int64])]
slugAccs SegBinOpSlug
slug) (SegBinOpSlug -> [SubExp]
slugNeutral SegBinOpSlug
slug)) ((((VName, [TExp Int64]), SubExp) -> InKernelGen ())
 -> InKernelGen ())
-> (((VName, [TExp Int64]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [TExp Int64]
acc_is), SubExp
ne) ->
        Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is ->
          VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
acc ([TExp Int64]
acc_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is) SubExp
ne []

  [Lambda KernelsMem]
slugs_op_renamed <- (SegBinOpSlug
 -> ImpM KernelsMem KernelEnv KernelOp (Lambda KernelsMem))
-> [SegBinOpSlug] -> InKernelGen [Lambda KernelsMem]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Lambda KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp (Lambda KernelsMem)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda (Lambda KernelsMem
 -> ImpM KernelsMem KernelEnv KernelOp (Lambda KernelsMem))
-> (SegBinOpSlug -> Lambda KernelsMem)
-> SegBinOpSlug
-> ImpM KernelsMem KernelEnv KernelOp (Lambda KernelsMem)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp KernelsMem -> Lambda KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda (SegBinOp KernelsMem -> Lambda KernelsMem)
-> (SegBinOpSlug -> SegBinOp KernelsMem)
-> SegBinOpSlug
-> Lambda KernelsMem
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOpSlug -> SegBinOp KernelsMem
slugOp) [SegBinOpSlug]
slugs

  let doTheReduction :: InKernelGen ()
doTheReduction =
        [(Lambda KernelsMem, SegBinOpSlug)]
-> ((Lambda KernelsMem, SegBinOpSlug) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Lambda KernelsMem]
-> [SegBinOpSlug] -> [(Lambda KernelsMem, SegBinOpSlug)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Lambda KernelsMem]
slugs_op_renamed [SegBinOpSlug]
slugs) (((Lambda KernelsMem, SegBinOpSlug) -> InKernelGen ())
 -> InKernelGen ())
-> ((Lambda KernelsMem, SegBinOpSlug) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Lambda KernelsMem
slug_op_renamed, SegBinOpSlug
slug) ->
          Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is -> do
            String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
comment String
"to reduce current chunk, first store our result in memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
              [(Param LetDecMem, (VName, [TExp Int64]))]
-> ((Param LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [(VName, [TExp Int64])]
-> [(Param LetDecMem, (VName, [TExp Int64]))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [LParam KernelsMem]
slugParams SegBinOpSlug
slug) (SegBinOpSlug -> [(VName, [TExp Int64])]
slugAccs SegBinOpSlug
slug)) (((Param LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
 -> InKernelGen ())
-> ((Param LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, (VName
acc, [TExp Int64]
acc_is)) ->
                VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] (VName -> SubExp
Var VName
acc) ([TExp Int64]
acc_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)

              [(VName, Param LetDecMem)]
-> ((VName, Param LetDecMem) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [Param LetDecMem] -> [(VName, Param LetDecMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [VName]
slugArrs SegBinOpSlug
slug) (SegBinOpSlug -> [LParam KernelsMem]
slugParams SegBinOpSlug
slug)) (((VName, Param LetDecMem) -> InKernelGen ()) -> InKernelGen ())
-> ((VName, Param LetDecMem) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, Param LetDecMem
p) ->
                Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param LetDecMem
p) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                  VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
arr [TExp Int64
local_tid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) []

            KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal -- Also implicitly barrier.
            TExp Int32 -> Lambda KernelsMem -> [VName] -> InKernelGen ()
groupReduce (TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)) Lambda KernelsMem
slug_op_renamed (SegBinOpSlug -> [VName]
slugArrs SegBinOpSlug
slug)

            KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal

            String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"first thread saves the result in accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              TPrimExp Bool ExpLeaf -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Int64
local_tid TExp Int64 -> TExp Int64 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                [((VName, [TExp Int64]), Param LetDecMem)]
-> (((VName, [TExp Int64]), Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [TExp Int64])]
-> [Param LetDecMem] -> [((VName, [TExp Int64]), Param LetDecMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [(VName, [TExp Int64])]
slugAccs SegBinOpSlug
slug) (Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
slug_op_renamed)) ((((VName, [TExp Int64]), Param LetDecMem) -> InKernelGen ())
 -> InKernelGen ())
-> (((VName, [TExp Int64]), Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [TExp Int64]
acc_is), Param LetDecMem
p) ->
                  VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
acc ([TExp Int64]
acc_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) []

  -- If this is a non-commutative reduction, each thread must run the
  -- loop the same number of iterations, because we will be performing
  -- a group-wide reduction in there.
  let comm :: Commutativity
comm = [SegBinOpSlug] -> Commutativity
slugsComm [SegBinOpSlug]
slugs
      (TExp Int64
bound, InKernelGen () -> InKernelGen ()
check_bounds) =
        case Commutativity
comm of
          Commutativity
Commutative -> (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
chunk_size, InKernelGen () -> InKernelGen ()
forall a. a -> a
id)
          Commutativity
Noncommutative ->
            ( Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
elems_per_thread,
              TPrimExp Bool ExpLeaf -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
gtid TExp Int64 -> TExp Int64 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
num_elements)
            )

  String
-> TExp Int64 -> (TExp Int64 -> InKernelGen ()) -> InKernelGen ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" TExp Int64
bound ((TExp Int64 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int64 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
    TV Int64
gtid
      TV Int64 -> TExp Int64 -> InKernelGen ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- case Commutativity
comm of
        Commutativity
Commutative ->
          TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
global_tid
            TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ VName -> TExp Int64
Imp.vi64 VName
threads_per_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
i
        Commutativity
Noncommutative ->
          let index_in_segment :: TExp Int32
index_in_segment = TExp Int32
global_tid TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
           in TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
local_tid
                TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ (TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
index_in_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
elems_per_thread TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i)
                TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)

    InKernelGen () -> InKernelGen ()
check_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"apply map function" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        DoSegBody
body DoSegBody -> DoSegBody
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [TExp Int64])]
all_red_res -> do
          let slugs_res :: [[(SubExp, [TExp Int64])]]
slugs_res = [Int] -> [(SubExp, [TExp Int64])] -> [[(SubExp, [TExp Int64])]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOpSlug -> Int) -> [SegBinOpSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOpSlug -> [SubExp]) -> SegBinOpSlug -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOpSlug -> [SubExp]
slugNeutral) [SegBinOpSlug]
slugs) [(SubExp, [TExp Int64])]
all_red_res

          [(SegBinOpSlug, [(SubExp, [TExp Int64])])]
-> ((SegBinOpSlug, [(SubExp, [TExp Int64])]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOpSlug]
-> [[(SubExp, [TExp Int64])]]
-> [(SegBinOpSlug, [(SubExp, [TExp Int64])])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOpSlug]
slugs [[(SubExp, [TExp Int64])]]
slugs_res) (((SegBinOpSlug, [(SubExp, [TExp Int64])]) -> InKernelGen ())
 -> InKernelGen ())
-> ((SegBinOpSlug, [(SubExp, [TExp Int64])]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOpSlug
slug, [(SubExp, [TExp Int64])]
red_res) ->
            Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is -> do
              String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"load accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                [(Param LetDecMem, (VName, [TExp Int64]))]
-> ((Param LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [(VName, [TExp Int64])]
-> [(Param LetDecMem, (VName, [TExp Int64]))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [LParam KernelsMem]
accParams SegBinOpSlug
slug) (SegBinOpSlug -> [(VName, [TExp Int64])]
slugAccs SegBinOpSlug
slug)) (((Param LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
 -> InKernelGen ())
-> ((Param LetDecMem, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, (VName
acc, [TExp Int64]
acc_is)) ->
                  VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] (VName -> SubExp
Var VName
acc) ([TExp Int64]
acc_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)
              String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"load new values" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                [(Param LetDecMem, (SubExp, [TExp Int64]))]
-> ((Param LetDecMem, (SubExp, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [(SubExp, [TExp Int64])]
-> [(Param LetDecMem, (SubExp, [TExp Int64]))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [LParam KernelsMem]
nextParams SegBinOpSlug
slug) [(SubExp, [TExp Int64])]
red_res) (((Param LetDecMem, (SubExp, [TExp Int64])) -> InKernelGen ())
 -> InKernelGen ())
-> ((Param LetDecMem, (SubExp, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, (SubExp
res, [TExp Int64]
res_is)) ->
                  VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] SubExp
res ([TExp Int64]
res_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)
              String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"apply reduction operator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                Names -> Stms KernelsMem -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body KernelsMem -> Stms KernelsMem
forall lore. BodyT lore -> Stms lore
bodyStms (Body KernelsMem -> Stms KernelsMem)
-> Body KernelsMem -> Stms KernelsMem
forall a b. (a -> b) -> a -> b
$ SegBinOpSlug -> Body KernelsMem
slugBody SegBinOpSlug
slug) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                  String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"store in accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                    [((VName, [TExp Int64]), SubExp)]
-> (((VName, [TExp Int64]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_
                      ( [(VName, [TExp Int64])]
-> [SubExp] -> [((VName, [TExp Int64]), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip
                          (SegBinOpSlug -> [(VName, [TExp Int64])]
slugAccs SegBinOpSlug
slug)
                          (Body KernelsMem -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Body KernelsMem -> [SubExp]) -> Body KernelsMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegBinOpSlug -> Body KernelsMem
slugBody SegBinOpSlug
slug)
                      )
                      ((((VName, [TExp Int64]), SubExp) -> InKernelGen ())
 -> InKernelGen ())
-> (((VName, [TExp Int64]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [TExp Int64]
acc_is), SubExp
se) ->
                        VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
acc ([TExp Int64]
acc_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is) SubExp
se []

    case Commutativity
comm of
      Commutativity
Noncommutative -> do
        InKernelGen ()
doTheReduction
        String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"first thread keeps accumulator; others reset to neutral element" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
          let reset_to_neutral :: InKernelGen ()
reset_to_neutral =
                [SegBinOpSlug]
-> (SegBinOpSlug -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [SegBinOpSlug]
slugs ((SegBinOpSlug -> InKernelGen ()) -> InKernelGen ())
-> (SegBinOpSlug -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegBinOpSlug
slug ->
                  [((VName, [TExp Int64]), SubExp)]
-> (((VName, [TExp Int64]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [TExp Int64])]
-> [SubExp] -> [((VName, [TExp Int64]), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOpSlug -> [(VName, [TExp Int64])]
slugAccs SegBinOpSlug
slug) (SegBinOpSlug -> [SubExp]
slugNeutral SegBinOpSlug
slug)) ((((VName, [TExp Int64]), SubExp) -> InKernelGen ())
 -> InKernelGen ())
-> (((VName, [TExp Int64]), SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [TExp Int64]
acc_is), SubExp
ne) ->
                    Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is ->
                      VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
acc ([TExp Int64]
acc_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is) SubExp
ne []
          TPrimExp Bool ExpLeaf -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sUnless (TExp Int64
local_tid TExp Int64 -> TExp Int64 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0) InKernelGen ()
reset_to_neutral
      Commutativity
_ -> () -> InKernelGen ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()

  ([Lambda KernelsMem], InKernelGen ())
-> InKernelGen ([Lambda KernelsMem], InKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return ([Lambda KernelsMem]
slugs_op_renamed, InKernelGen ()
doTheReduction)

-- | First stage of a (possibly segmented) reduction.
--
-- All arguments are forwarded unchanged to 'reductionStageZero'. That
-- function runs the per-thread accumulation loop and hands back two
-- things: the renamed reduction operators, and an action that reduces
-- the current accumulators across the workgroup. What happens next
-- depends on the commutativity of the slugs:
--
-- * Non-commutative: 'reductionStageZero' already ran that workgroup
--   reduction after every chunk, so the only work left here is to load
--   each accumulator into the matching accumulator parameter
--   ('accParams') of its operator.
--
-- * Commutative: the workgroup reduction has not been run yet, so it
--   is performed now, exactly once.
--
-- Returns the renamed operator lambdas, in the same order as @slugs@.
reductionStageOne ::
  KernelConstants ->
  [(VName, Imp.TExp Int64)] ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  Imp.TExp Int32 ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  VName ->
  [SegBinOpSlug] ->
  DoSegBody ->
  InKernelGen [Lambda KernelsMem]
reductionStageOne constants ispace num_elements global_tid elems_per_thread threads_per_segment slugs body = do
  (renamed_ops, reduce_in_group) <-
    reductionStageZero
      constants
      ispace
      num_elements
      global_tid
      elems_per_thread
      threads_per_segment
      slugs
      body

  -- Every (accumulator parameter, accumulator) pair, slug by slug, in
  -- the order a nested slug/accumulator traversal would visit them.
  -- Each accumulator is indexed by its own @acc_is@ only, without any
  -- per-element vector indices. NOTE(review): for slugs with a
  -- non-empty 'slugShape' this relies on 'copyDWIMFix' handling the
  -- resulting slice; confirm.
  let acc_pairs = concatMap (\slug -> zip (accParams slug) (slugAccs slug)) slugs
      load_acc (p, (acc, acc_is)) = copyDWIMFix (paramName p) [] (Var acc) acc_is

  case slugsComm slugs of
    Noncommutative -> mapM_ load_acc acc_pairs
    _ -> reduce_in_group

  pure renamed_ops

reductionStageTwo ::
  KernelConstants ->
  [PatElem KernelsMem] ->
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  [Imp.TExp Int64] ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  SegBinOpSlug ->
  [LParam KernelsMem] ->
  [LParam KernelsMem] ->
  Lambda KernelsMem ->
  [SubExp] ->
  Imp.TExp Int32 ->
  VName ->
  Imp.TExp Int32 ->
  VName ->
  [VName] ->
  [VName] ->
  InKernelGen ()
reductionStageTwo :: KernelConstants
-> [PatElem KernelsMem]
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> TExp Int64
-> TExp Int64
-> SegBinOpSlug
-> [LParam KernelsMem]
-> [LParam KernelsMem]
-> Lambda KernelsMem
-> [SubExp]
-> TExp Int32
-> VName
-> TExp Int32
-> VName
-> [VName]
-> [VName]
-> InKernelGen ()
reductionStageTwo
  KernelConstants
constants
  [PatElem KernelsMem]
segred_pes
  TExp Int32
group_id
  TExp Int32
flat_segment_id
  [TExp Int64]
segment_gtids
  TExp Int64
first_group_for_segment
  TExp Int64
groups_per_segment
  SegBinOpSlug
slug
  [LParam KernelsMem]
red_x_params
  [LParam KernelsMem]
red_y_params
  Lambda KernelsMem
red_op_renamed
  [SubExp]
nes
  TExp Int32
num_counters
  VName
counter
  TExp Int32
counter_i
  VName
sync_arr
  [VName]
group_res_arrs
  [VName]
red_arrs = do
    let local_tid :: TExp Int32
local_tid = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
        group_size :: TExp Int64
group_size = KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants
    TV Int64
old_counter <- String -> PrimType -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"old_counter" PrimType
int32
    (VName
counter_mem, Space
_, Count Elements (TExp Int64)
counter_offset) <-
      VName
-> [TExp Int64]
-> ImpM
     KernelsMem
     KernelEnv
     KernelOp
     (VName, Space, Count Elements (TExp Int64))
forall lore r op.
VName
-> [TExp Int64]
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray
        VName
counter
        [ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$
            TExp Int32
counter_i TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
num_counters
              TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
flat_segment_id TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int32
num_counters
        ]
    String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
comment String
"first thread in group saves group result to global memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      TPrimExp Bool ExpLeaf -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
        [(VName, (VName, [TExp Int64]))]
-> ((VName, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (Int
-> [(VName, (VName, [TExp Int64]))]
-> [(VName, (VName, [TExp Int64]))]
forall a. Int -> [a] -> [a]
take ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([(VName, (VName, [TExp Int64]))]
 -> [(VName, (VName, [TExp Int64]))])
-> [(VName, (VName, [TExp Int64]))]
-> [(VName, (VName, [TExp Int64]))]
forall a b. (a -> b) -> a -> b
$ [VName]
-> [(VName, [TExp Int64])] -> [(VName, (VName, [TExp Int64]))]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
group_res_arrs (SegBinOpSlug -> [(VName, [TExp Int64])]
slugAccs SegBinOpSlug
slug)) (((VName, (VName, [TExp Int64])) -> InKernelGen ())
 -> InKernelGen ())
-> ((VName, (VName, [TExp Int64])) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
v, (VName
acc, [TExp Int64]
acc_is)) ->
          VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
v [TExp Int64
0, TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
group_id] (VName -> SubExp
Var VName
acc) [TExp Int64]
acc_is
        KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceGlobal
        -- Increment the counter, thus stating that our result is
        -- available.
        KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          Space -> AtomicOp -> KernelOp
Imp.Atomic Space
DefaultSpace (AtomicOp -> KernelOp) -> AtomicOp -> KernelOp
forall a b. (a -> b) -> a -> b
$
            IntType
-> VName -> VName -> Count Elements (TExp Int64) -> Exp -> AtomicOp
Imp.AtomicAdd
              IntType
Int32
              (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
old_counter)
              VName
counter_mem
              Count Elements (TExp Int64)
counter_offset
              (Exp -> AtomicOp) -> Exp -> AtomicOp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32
1 :: Imp.TExp Int32)
        -- Now check if we were the last group to write our result.  If
        -- so, it is our responsibility to produce the final result.
        VName -> [TExp Int64] -> Exp -> InKernelGen ()
forall lore r op. VName -> [TExp Int64] -> Exp -> ImpM lore r op ()
sWrite VName
sync_arr [TExp Int64
0] (Exp -> InKernelGen ()) -> Exp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ TPrimExp Bool ExpLeaf -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Bool ExpLeaf -> Exp) -> TPrimExp Bool ExpLeaf -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
old_counter TExp Int64 -> TExp Int64 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
groups_per_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1

    KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceGlobal

    TV Bool
is_last_group <- String -> PrimType -> ImpM KernelsMem KernelEnv KernelOp (TV Bool)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"is_last_group" PrimType
Bool
    VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (TV Bool -> VName
forall t. TV t -> VName
tvVar TV Bool
is_last_group) [] (VName -> SubExp
Var VName
sync_arr) [TExp Int64
0]
    TPrimExp Bool ExpLeaf -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TV Bool -> TPrimExp Bool ExpLeaf
forall t. TV t -> TExp t
tvExp TV Bool
is_last_group) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
      -- The final group has written its result (and it was
      -- us!), so read in all the group results and perform the
      -- final stage of the reduction.  But first, we reset the
      -- counter so it is ready for next time.  This is done
      -- with an atomic to avoid warnings about write/write
      -- races in oclgrind.
      TPrimExp Bool ExpLeaf -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          Space -> AtomicOp -> KernelOp
Imp.Atomic Space
DefaultSpace (AtomicOp -> KernelOp) -> AtomicOp -> KernelOp
forall a b. (a -> b) -> a -> b
$
            IntType
-> VName -> VName -> Count Elements (TExp Int64) -> Exp -> AtomicOp
Imp.AtomicAdd IntType
Int32 (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
old_counter) VName
counter_mem Count Elements (TExp Int64)
counter_offset (Exp -> AtomicOp) -> Exp -> AtomicOp
forall a b. (a -> b) -> a -> b
$
              TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int64
forall a. Num a => a -> a
negate TExp Int64
groups_per_segment

      Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegBinOpSlug -> Shape
slugShape SegBinOpSlug
slug) (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is -> do
        -- There is no guarantee that the number of workgroups for the
        -- segment is less than the workgroup size, so each thread may
        -- have to read multiple elements.  We do this in a sequential
        -- way that may induce non-coalesced accesses, but the total
        -- number of accesses should be tiny here.
        String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
comment String
"read in the per-group-results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
          TExp Int64
read_per_thread <-
            String
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"read_per_thread" (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
              TExp Int64
groups_per_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
group_size

          [(Param LetDecMem, SubExp)]
-> ((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [SubExp] -> [(Param LetDecMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [LParam KernelsMem]
[Param LetDecMem]
red_x_params [SubExp]
nes) (((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, SubExp
ne) ->
            VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] SubExp
ne []

          String
-> TExp Int64 -> (TExp Int64 -> InKernelGen ()) -> InKernelGen ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" TExp Int64
read_per_thread ((TExp Int64 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int64 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
            TExp Int64
group_res_id <-
              String
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"group_res_id" (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
                TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
local_tid TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
read_per_thread TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i
            TExp Int64
index_of_group_res <-
              String
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"index_of_group_res" (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
                TExp Int64
first_group_for_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
group_res_id

            TPrimExp Bool ExpLeaf -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Int64
group_res_id TExp Int64 -> TExp Int64 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
groups_per_segment) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
              [(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [LParam KernelsMem]
[Param LetDecMem]
red_y_params [VName]
group_res_arrs) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                \(Param LetDecMem
p, VName
group_res_arr) ->
                  VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix
                    (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p)
                    []
                    (VName -> SubExp
Var VName
group_res_arr)
                    ([TExp Int64
0, TExp Int64
index_of_group_res] [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)

              Names -> Stms KernelsMem -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body KernelsMem -> Stms KernelsMem
forall lore. BodyT lore -> Stms lore
bodyStms (Body KernelsMem -> Stms KernelsMem)
-> Body KernelsMem -> Stms KernelsMem
forall a b. (a -> b) -> a -> b
$ SegBinOpSlug -> Body KernelsMem
slugBody SegBinOpSlug
slug) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                [(Param LetDecMem, SubExp)]
-> ((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [SubExp] -> [(Param LetDecMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [LParam KernelsMem]
[Param LetDecMem]
red_x_params (Body KernelsMem -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Body KernelsMem -> [SubExp]) -> Body KernelsMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegBinOpSlug -> Body KernelsMem
slugBody SegBinOpSlug
slug)) (((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, SubExp
se) ->
                  VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] SubExp
se []

        [(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [LParam KernelsMem]
[Param LetDecMem]
red_x_params [VName]
red_arrs) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, VName
arr) ->
          Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param LetDecMem
p) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
arr [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
local_tid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) []

        KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal

        String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"reduce the per-group results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
          TExp Int32 -> Lambda KernelsMem -> [VName] -> InKernelGen ()
groupReduce (TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
group_size) Lambda KernelsMem
red_op_renamed [VName]
red_arrs

          String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"and back to memory with the final result" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            TPrimExp Bool ExpLeaf -> InKernelGen () -> InKernelGen ()
forall lore r op.
TPrimExp Bool ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              [(PatElemT LetDecMem, Param LetDecMem)]
-> ((PatElemT LetDecMem, Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem]
-> [Param LetDecMem] -> [(PatElemT LetDecMem, Param LetDecMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem KernelsMem]
[PatElemT LetDecMem]
segred_pes ([Param LetDecMem] -> [(PatElemT LetDecMem, Param LetDecMem)])
-> [Param LetDecMem] -> [(PatElemT LetDecMem, Param LetDecMem)]
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
red_op_renamed) (((PatElemT LetDecMem, Param LetDecMem) -> InKernelGen ())
 -> InKernelGen ())
-> ((PatElemT LetDecMem, Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
pe, Param LetDecMem
p) ->
                VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix
                  (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
pe)
                  ([TExp Int64]
segment_gtids [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)
                  (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p)
                  []