{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
-- | We generate code for non-segmented/single-segment SegRed using
-- the basic approach outlined in the paper "Design and GPGPU
-- Performance of Futhark’s Redomap Construct" (ARRAY '16).  The main
-- deviations are:
--
-- * While we still use two-phase reduction, we use only a single
--   kernel, with the final workgroup to write a result (tracked via
--   an atomic counter) performing the final reduction as well.
--
-- * Instead of depending on storage layout transformations to handle
--   non-commutative reductions efficiently, we slide a
--   'groupsize'-sized window over the input, and perform a parallel
--   reduction for each window.  This sacrifices the notion of
--   efficient sequentialisation, but is sometimes faster and
--   definitely simpler and more predictable (and uses less auxiliary
--   storage).
--
-- For segmented reductions we use the approach from "Strategies for
-- Regular Segmented Reductions on GPU" (FHPC '17).  This involves
-- having two different strategies, and dynamically deciding which one
-- to use based on the number of segments and segment size. We use the
-- (static) @group_size@ to decide which of the following two
-- strategies to choose:
--
-- * Large: uses one or more groups to process a single segment. If
--   multiple groups are used per segment, the intermediate reduction
--   results must be recursively reduced, until there is only a single
--   value per segment.
--
--   Each thread /can/ read multiple elements, which will greatly
--   increase performance; however, if the reduction is
--   non-commutative we will have to use a less efficient traversal
--   (with interim group-wide reductions) to enable coalesced memory
--   accesses, just as in the non-segmented case.
--
-- * Small: is used to let each group process *multiple* segments
--   within a group. We will only use this approach when we can
--   process at least two segments within a single group. In those
--   cases, we would allocate a /whole/ group per segment with the
--   large strategy, but at most 50% of the threads in the group would
--   have any element to read, which becomes highly inefficient.
module Futhark.CodeGen.ImpGen.Kernels.SegRed
  ( compileSegRed
  , compileSegRed'
  , DoSegBody
  )
  where

import Control.Monad.Except
import Data.Maybe
import Data.List (genericLength, zip4, zip7)

import Prelude hiding (quot, rem)

import Futhark.Error
import Futhark.Transform.Rename
import Futhark.Representation.ExplicitMemory
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.Base
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Util (chunks)
import Futhark.Util.IntegralExp (quotRoundingUp, quot, rem)

-- | The maximum number of operators we support in a single SegRed.
-- This limit arises out of the static allocation of counters: one
-- counter slot is allocated per permitted operator.
maxNumOps :: Int32
maxNumOps = 10

-- | Code generation for the body of a 'SegRed', written in
-- continuation-passing style.  The action is handed a continuation
-- and is expected to invoke it with the results of the body, each
-- given as a 'SubExp' paired with a list of expressions.
--
-- NOTE(review): the expression list is presumably a set of indexes
-- into the 'SubExp' used when reading the result (it is empty when
-- the body is an ordinary kernel body) -- confirm against the
-- consumers of the continuation.
type DoSegBody = ([(SubExp, [Imp.Exp])] -> InKernelGen ()) -> InKernelGen ()

-- | Compile 'SegRed' instance to host-level code with calls to
-- various kernels.
--
-- This is a thin wrapper around 'compileSegRed'' that turns an
-- ordinary 'KernelBody' into a 'DoSegBody': the statements of the
-- body are compiled, the map-out results (those following the first
-- @segRedResults reds@ results) are written to the corresponding
-- pattern elements, and the reduction results are handed to the
-- continuation, each with an empty index list.
compileSegRed :: Pattern ExplicitMemory
              -> SegLevel -> SegSpace
              -> [SegRedOp ExplicitMemory]
              -> KernelBody ExplicitMemory
              -> CallKernelGen ()
compileSegRed pat lvl space reds body =
  compileSegRed' pat lvl space reds $ \red_cont ->
    compileStms mempty (kernelBodyStms body) $ do
      -- The first 'segRedResults' results feed the reduction
      -- operators; the remainder are plain per-thread map results.
      let (red_res, map_res) =
            splitAt (segRedResults reds) $ kernelBodyResult body

      sComment "save map-out results" $ do
        let map_arrs = drop (segRedResults reds) $ patternElements pat
        zipWithM_ (compileThreadResult space) map_arrs map_res

      red_cont $ zip (map kernelResultSubExp red_res) $ repeat []

-- | Like 'compileSegRed', but where the body is a monadic action.
--
-- Picks the code generation strategy:
--
-- * More than 'maxNumOps' operators is reported as a compiler
--   limitation.
--
-- * A two-dimensional space whose outer dimension is the constant 1
--   is treated as a non-segmented reduction.
--
-- * Otherwise a run-time test on the innermost dimension (the
--   segment size) chooses between the small-segments strategy (when
--   @segment_size * 2 < group_size@) and the large-segments strategy.
compileSegRed' :: Pattern ExplicitMemory
               -> SegLevel -> SegSpace
               -> [SegRedOp ExplicitMemory]
               -> DoSegBody
               -> CallKernelGen ()
compileSegRed' pat lvl space reds body
  | genericLength reds > maxNumOps =
      compilerLimitationS $
      "compileSegRed': at most " ++ show maxNumOps ++ " reduction operators are supported."
  | [(_, Constant (IntValue (Int32Value 1))), _] <- unSegSpace space =
      nonsegmentedReduction pat num_groups group_size space reds body
  | otherwise = do
      group_size' <- toExp $ unCount group_size
      segment_size <- toExp $ last $ segSpaceDims space
      -- Small segments only pay off when at least two of them fit in
      -- a single group.
      let use_small_segments = segment_size * 2 .<. group_size'
      sIf use_small_segments
        (smallSegmentsReduction pat num_groups group_size space reds body)
        (largeSegmentsReduction pat num_groups group_size space reds body)
  where num_groups = segNumGroups lvl
        group_size = segGroupSize lvl

-- | Prepare intermediate arrays for the reduction.  Prim-typed
-- arguments go in local memory (so we need to do the allocation of
-- those arrays inside the kernel), while array-typed arguments go in
-- global memory.  Allocations for the former have already been
-- performed.  This policy is baked into how the allocations are done
-- in ExplicitAllocations.
--
-- One array (named @red_arr@) is produced per accumulator parameter
-- of the reduction operator, i.e. for the first @length nes@ lambda
-- parameters:
--
-- * If the parameter is an array residing in some memory block, the
--   result is an array in that same block with shape
--   @[num_threads] <> shape@ and a row-major ('IxFun.iota') index
--   function.
--
-- * Otherwise a fresh array of @group_size@ elements is allocated in
--   the @"local"@ space.
intermediateArrays :: Count GroupSize SubExp -> SubExp
                   -> SegRedOp ExplicitMemory
                   -> InKernelGen [VName]
intermediateArrays (Count group_size) num_threads (SegRedOp _ red_op nes _) = do
  let red_op_params = lambdaParams red_op
      (red_acc_params, _) = splitAt (length nes) red_op_params
  forM red_acc_params $ \p ->
    case paramAttr p of
      MemArray pt shape _ (ArrayIn mem _) -> do
        let shape' = Shape [num_threads] <> shape
        sArray "red_arr" pt shape' $
          ArrayIn mem $ IxFun.iota $
          map (primExpFromSubExp int32) $ shapeDims shape'
      _ -> do
        let pt = elemType $ paramType p
            shape = Shape [group_size]
        sAllocArray "red_arr" pt shape $ Space "local"

-- | Arrays for storing group results.
--
-- The group-result arrays have an extra dimension (of size groupsize)
-- because they are also used for keeping vectorised accumulators for
-- first-stage reduction, if necessary.  When actually storing group
-- results, the first index is set to 0.
--
-- The result has one inner list per reduction operator, containing
-- one array (named @group_res_arr@, allocated in the @"device"@
-- space) per return type of that operator's lambda.
groupResultArrays :: Count NumGroups SubExp -> Count GroupSize SubExp
                  -> [SegRedOp ExplicitMemory]
                  -> CallKernelGen [[VName]]
groupResultArrays (Count virt_num_groups) (Count group_size) reds =
  forM reds $ \(SegRedOp _ lam _ shape) ->
    forM (lambdaReturnType lam) $ \t -> do
      let pt = elemType t
          full_shape =
            Shape [group_size, virt_num_groups] <> shape <> arrayShape t
          -- Move the groupsize dimension last to ensure coalesced
          -- memory access.
          perm = [1 .. shapeRank full_shape - 1] ++ [0]
      sAllocArrayPerm "group_res_arr" pt full_shape (Space "device") perm

-- | Generate code for a non-segmented (single-segment) reduction.
--
-- On the host side this allocates a statically initialised
-- @counter@ array of 'maxNumOps' zeroes in the @"device"@ space, the
-- group-result arrays for every operator, and the total thread count
-- (@num_groups * group_size@).  It then emits a single
-- @segred_nonseg@ kernel which runs 'reductionStageOne' over all
-- operators, followed by one 'reductionStageTwo' per operator, each
-- operator using its own slot in @counter@.
nonsegmentedReduction :: Pattern ExplicitMemory
                      -> Count NumGroups SubExp -> Count GroupSize SubExp -> SegSpace
                      -> [SegRedOp ExplicitMemory]
                      -> DoSegBody
                      -> CallKernelGen ()
nonsegmentedReduction segred_pat num_groups group_size space reds body = do
  let (gtids, dims) = unzip $ unSegSpace space
  dims' <- mapM toExp dims

  num_groups' <- traverse toExp num_groups
  group_size' <- traverse toExp group_size

  let global_tid = Imp.vi32 $ segFlat space
      -- The innermost dimension is the number of elements to reduce.
      w = last dims'

  counter <-
    sStaticArray "counter" (Space "device") int32 $
    Imp.ArrayValues $ replicate (fromIntegral maxNumOps) $ IntValue $ Int32Value 0

  reds_group_res_arrs <- groupResultArrays num_groups group_size reds

  num_threads <-
    dPrimV "num_threads" $ unCount num_groups' * unCount group_size'

  emit $ Imp.DebugPrint "\n# SegRed" Nothing

  sKernelThread "segred_nonseg" num_groups' group_size' (segFlat space) $ do
    constants <- kernelConstants <$> askEnv
    sync_arr <-
      sAllocArray "sync_arr" Bool (Shape [intConst Int32 1]) $ Space "local"
    reds_arrs <- mapM (intermediateArrays group_size (Var num_threads)) reds

    -- Since this is the nonsegmented case, all outer segment IDs must
    -- necessarily be 0.
    forM_ gtids $ \v -> dPrimV_ v 0

    let num_elements = Imp.elements w
    let elems_per_thread =
          num_elements `quotRoundingUp`
          Imp.elements (kernelNumThreads constants)

    slugs <-
      mapM (segRedOpSlug (kernelLocalThreadId constants) (kernelGroupId constants)) $
      zip3 reds reds_arrs reds_group_res_arrs
    reds_op_renamed <-
      reductionStageOne constants (zip gtids dims') num_elements
      global_tid elems_per_thread num_threads
      slugs body

    -- Split the pattern elements so that each operator gets as many
    -- as it has neutral elements.
    let segred_pes =
          chunks (map (length . segRedNeutral) reds) $
          patternElements segred_pat
    forM_ (zip7 reds reds_arrs reds_group_res_arrs segred_pes
           slugs reds_op_renamed [0..]) $
      \(SegRedOp _ red_op nes _,
        red_arrs, group_res_arrs, pes, slug, red_op_renamed, i) -> do
        let red_acc_params = take (length nes) $ lambdaParams red_op
        -- The operator's position @i@ selects its slot in @counter@.
        reductionStageTwo constants pes (kernelGroupId constants) 0 [0] 0
          (kernelNumGroups constants) slug red_acc_params red_op_renamed nes
          1 counter (ValueExp $ IntValue $ Int32Value i)
          sync_arr group_res_arrs red_arrs

-- | Generate the kernel for the "small segments" strategy, in which
-- every workgroup reduces several whole segments at once.
--
-- The innermost dimension of the 'SegSpace' is the segment size; the
-- outer dimensions together enumerate the segments.  Each virtual
-- group handles @group_size \`quot\` segment_size@ segments: every
-- thread stores one mapped element (or the neutral element, if it has
-- no element to process) in the intermediate arrays, a segmented
-- group-wide scan is run over them, and the last element of each
-- segment is copied to the pattern elements as that segment's result.
--
-- Arguments, in order: the pattern whose value elements receive the
-- per-segment results, the number of groups and the group size to
-- launch the kernel with, the iteration space, the reduction
-- operators, and the body producing the values to be reduced.
smallSegmentsReduction :: Pattern ExplicitMemory
                       -> Count NumGroups SubExp -> Count GroupSize SubExp
                       -> SegSpace
                       -> [SegRedOp ExplicitMemory]
                       -> DoSegBody
                       -> CallKernelGen ()
smallSegmentsReduction (Pattern _ segred_pes) num_groups group_size space reds body = do
  let (gtids, dims) = unzip $ unSegSpace space
  dims' <- mapM toExp dims

  let segment_size = last dims'
  -- Careful to avoid division by zero now: all divisions and
  -- remainders that are emitted unconditionally must use this clamped
  -- value (max 1 segment_size) instead of the raw segment size.
  segment_size_nonzero_v <- dPrimV "segment_size_nonzero" $
                            BinOpExp (SMax Int32) 1 segment_size

  num_groups' <- traverse toExp num_groups
  group_size' <- traverse toExp group_size
  num_threads <- dPrimV "num_threads" $ unCount num_groups' * unCount group_size'
  let segment_size_nonzero = Imp.var segment_size_nonzero_v int32
      num_segments = product $ init dims'
      segments_per_group = unCount group_size' `quot` segment_size_nonzero
      required_groups = num_segments `quotRoundingUp` segments_per_group

  emit $ Imp.DebugPrint "\n# SegRed-small" Nothing
  emit $ Imp.DebugPrint "num_segments" $ Just num_segments
  emit $ Imp.DebugPrint "segment_size" $ Just segment_size
  emit $ Imp.DebugPrint "segments_per_group" $ Just segments_per_group
  emit $ Imp.DebugPrint "required_groups" $ Just required_groups

  sKernelThread "segred_small" num_groups' group_size' (segFlat space) $ do
    constants <- kernelConstants <$> askEnv
    reds_arrs <- mapM (intermediateArrays group_size (Var num_threads)) reds

    -- We probably do not have enough actual workgroups to cover the
    -- entire iteration space.  Some groups thus have to perform double
    -- duty; we put an outer loop to accomplish this.
    virtualiseGroups SegVirt required_groups $ \group_id_var' -> do
      let group_id' = Imp.vi32 group_id_var'
      -- Compute the 'n' input indices.  The outer 'n-1' correspond to
      -- the segment ID, and are computed from the group id.  The inner
      -- is computed from the local thread id, and may be out-of-bounds.
      let ltid = kernelLocalThreadId constants
          segment_index = (ltid `quot` segment_size_nonzero) + (group_id' * segments_per_group)
          -- This binding is emitted unconditionally (it is not covered
          -- by the 'segment_size .>. 0' guards below), so the divisor
          -- must be the clamped size; otherwise the kernel would take
          -- a remainder by zero for empty segments.  Whenever
          -- segment_size is positive the two divisors are equal.
          index_within_segment = ltid `rem` segment_size_nonzero

      zipWithM_ dPrimV_ (init gtids) $ unflattenIndex (init dims') segment_index
      dPrimV_ (last gtids) index_within_segment

      let -- Threads without an element to process store the neutral
          -- element of each operator in their slot instead.
          out_of_bounds =
            forM_ (zip reds reds_arrs) $ \(SegRedOp _ _ nes _, red_arrs) ->
            forM_ (zip red_arrs nes) $ \(arr, ne) ->
            copyDWIMFix arr [ltid] ne []

          -- Threads with an element run the body and store its
          -- results at their own position in the intermediate arrays.
          in_bounds =
            body $ \red_res ->
            sComment "save results to be reduced" $ do
            let red_dests = zip (concat reds_arrs) $ repeat [ltid]
            forM_ (zip red_dests red_res) $ \((d, d_is), (res, res_is)) ->
              copyDWIMFix d d_is res res_is

      sComment "apply map function if in bounds" $
        sIf (segment_size .>. 0 .&&.
             isActive (init $ zip gtids dims) .&&.
             ltid .<. segment_size * segments_per_group) in_bounds out_of_bounds

      sOp $ Imp.ErrorSync Imp.FenceLocal -- Also implicitly barrier.

      -- Tells the scan whether combining the elements at positions
      -- 'from' and 'to' would cross a segment boundary.  Only used
      -- under the 'segment_size .>. 0' guard just below.
      let crossesSegment from to = (to - from) .>. (to `rem` segment_size)
      sWhen (segment_size .>. 0) $
        sComment "perform segmented scan to imitate reduction" $
        forM_ (zip reds reds_arrs) $ \(SegRedOp _ red_op _ _, red_arrs) ->
        groupScan (Just crossesSegment) (Imp.vi32 num_threads)
        (segment_size * segments_per_group) red_op red_arrs

      sOp $ Imp.Barrier Imp.FenceLocal

      sComment "save final values of segments" $
        sWhen (group_id' * segments_per_group + ltid .<. num_segments .&&.
               ltid .<. segments_per_group) $
        forM_ (zip segred_pes (concat reds_arrs)) $ \(pe, arr) -> do
        -- Figure out which segment result this thread should write...
        let flat_segment_index = group_id' * segments_per_group + ltid
            gtids' = unflattenIndex (init dims') flat_segment_index
        -- ...and read it from the last slot of that segment in the
        -- scanned intermediate array.
        copyDWIMFix (patElemName pe) gtids'
                        (Var arr) [(ltid + 1) * segment_size_nonzero - 1]

      -- Finally another barrier, because we will be writing to the
      -- local memory array first thing in the next iteration.
      sOp $ Imp.Barrier Imp.FenceLocal

largeSegmentsReduction :: Pattern ExplicitMemory
                       -> Count NumGroups SubExp -> Count GroupSize SubExp
                       -> SegSpace
                       -> [SegRedOp ExplicitMemory]
                       -> DoSegBody
                       -> CallKernelGen ()
largeSegmentsReduction :: Pattern ExplicitMemory
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegRedOp ExplicitMemory]
-> DoSegBody
-> CallKernelGen ()
largeSegmentsReduction Pattern ExplicitMemory
segred_pat Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegRedOp ExplicitMemory]
reds DoSegBody
body = do
  let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
  [Exp]
dims' <- (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> [SubExp] -> ImpM ExplicitMemory HostEnv HostOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp [SubExp]
dims
  let segment_size :: Exp
segment_size = [Exp] -> Exp
forall a. [a] -> a
last [Exp]
dims'
      num_segments :: Exp
num_segments = [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> [Exp]
forall a. [a] -> [a]
init [Exp]
dims'

  Count NumGroups Exp
num_groups' <- (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Count NumGroups SubExp
-> ImpM ExplicitMemory HostEnv HostOp (Count NumGroups Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count NumGroups SubExp
num_groups
  Count GroupSize Exp
group_size' <- (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Count GroupSize SubExp
-> ImpM ExplicitMemory HostEnv HostOp (Count GroupSize Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count GroupSize SubExp
group_size

  let (Exp
groups_per_segment, Count Elements Exp
elems_per_thread) =
        Exp
-> Exp
-> Count NumGroups Exp
-> Count GroupSize Exp
-> (Exp, Count Elements Exp)
groupsPerSegmentAndElementsPerThread Exp
segment_size Exp
num_segments
        Count NumGroups Exp
num_groups' Count GroupSize Exp
group_size'
  VName
virt_num_groups <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"vit_num_groups" (Exp -> ImpM ExplicitMemory HostEnv HostOp VName)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
    Exp
groups_per_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
num_segments

  VName
num_threads <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"num_threads" (Exp -> ImpM ExplicitMemory HostEnv HostOp VName)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
num_groups' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size'

  VName
threads_per_segment <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"thread_per_segment" (Exp -> ImpM ExplicitMemory HostEnv HostOp VName)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
    Exp
groups_per_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size'

  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"\n# SegRed-large" Maybe Exp
forall a. Maybe a
Nothing
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"num_segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
num_segments
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"segment_size" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
segment_size
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"virt_num_groups" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ VName -> Exp
Imp.vi32 VName
virt_num_groups
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"num_groups" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ Count NumGroups Exp -> Exp
forall u e. Count u e -> e
Imp.unCount Count NumGroups Exp
num_groups'
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"group_size" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ Count GroupSize Exp -> Exp
forall u e. Count u e -> e
Imp.unCount Count GroupSize Exp
group_size'
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"elems_per_thread" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ Count Elements Exp -> Exp
forall u e. Count u e -> e
Imp.unCount Count Elements Exp
elems_per_thread
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"groups_per_segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
groups_per_segment

  [[VName]]
reds_group_res_arrs <- Count NumGroups SubExp
-> Count GroupSize SubExp
-> [SegRedOp ExplicitMemory]
-> CallKernelGen [[VName]]
groupResultArrays (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Count (VName -> SubExp
Var VName
virt_num_groups)) Count GroupSize SubExp
group_size [SegRedOp ExplicitMemory]
reds

  -- In principle we should have a counter for every segment.  Since
  -- the number of segments is a dynamic quantity, we would have to
  -- allocate and zero out an array here, which is expensive.
  -- However, we exploit the fact that the number of segments being
  -- reduced at any point in time is limited by the number of
  -- workgroups. If we bound the number of workgroups, we can get away
  -- with using that many counters.  FIXME: Is this limit checked
  -- anywhere?  There are other places in the compiler that will fail
  -- if the group count exceeds the maximum group size, which is at
  -- most 1024 anyway.
  let num_counters :: Int
num_counters = Int32 -> Int
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int32
maxNumOps Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
1024
  VName
counter <-
    String
-> Space
-> PrimType
-> ArrayContents
-> ImpM ExplicitMemory HostEnv HostOp VName
forall lore r op.
String
-> Space -> PrimType -> ArrayContents -> ImpM lore r op VName
sStaticArray String
"counter" (String -> Space
Space String
"device") PrimType
int32 (ArrayContents -> ImpM ExplicitMemory HostEnv HostOp VName)
-> ArrayContents -> ImpM ExplicitMemory HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
    Int -> ArrayContents
Imp.ArrayZeros Int
num_counters

  String
-> Count NumGroups Exp
-> Count GroupSize Exp
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"segred_large" Count NumGroups Exp
num_groups' Count GroupSize Exp
group_size' (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
-> ImpM ExplicitMemory KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
    [[VName]]
reds_arrs <- (SegRedOp ExplicitMemory -> InKernelGen [VName])
-> [SegRedOp ExplicitMemory]
-> ImpM ExplicitMemory KernelEnv KernelOp [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Count GroupSize SubExp
-> SubExp -> SegRedOp ExplicitMemory -> InKernelGen [VName]
intermediateArrays Count GroupSize SubExp
group_size (VName -> SubExp
Var VName
num_threads)) [SegRedOp ExplicitMemory]
reds
    VName
sync_arr <- String
-> PrimType
-> Shape
-> Space
-> ImpM ExplicitMemory KernelEnv KernelOp VName
forall lore r op.
String -> PrimType -> Shape -> Space -> ImpM lore r op VName
sAllocArray String
"sync_arr" PrimType
Bool ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
1]) (Space -> ImpM ExplicitMemory KernelEnv KernelOp VName)
-> Space -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ String -> Space
Space String
"local"

    -- We probably do not have enough actual workgroups to cover the
    -- entire iteration space.  Some groups thus have to perform double
    -- duty; we put an outer loop to accomplish this.
    SegVirt -> Exp -> (VName -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt (VName -> Exp
Imp.vi32 VName
virt_num_groups) ((VName -> InKernelGen ()) -> InKernelGen ())
-> (VName -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
group_id_var -> do
      let segment_gtids :: [VName]
segment_gtids = [VName] -> [VName]
forall a. [a] -> [a]
init [VName]
gtids
          group_id :: Exp
group_id = VName -> Exp
Imp.vi32 VName
group_id_var
          flat_segment_id :: Exp
flat_segment_id = Exp
group_id Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`quot` Exp
groups_per_segment
          local_tid :: Exp
local_tid = KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants

          global_tid :: Exp
global_tid = (Exp
group_id Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
local_tid)
                       Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`rem` (Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
groups_per_segment)
          w :: SubExp
w = [SubExp] -> SubExp
forall a. [a] -> a
last [SubExp]
dims
          first_group_for_segment :: Exp
first_group_for_segment = Exp
flat_segment_id Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
groups_per_segment

      (VName -> Exp -> InKernelGen ())
-> [VName] -> [Exp] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
dPrimV_ [VName]
segment_gtids ([Exp] -> InKernelGen ()) -> [Exp] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ([Exp] -> [Exp]
forall a. [a] -> [a]
init [Exp]
dims') Exp
flat_segment_id
      VName -> PrimType -> InKernelGen ()
forall lore r op. VName -> PrimType -> ImpM lore r op ()
dPrim_ ([VName] -> VName
forall a. [a] -> a
last [VName]
gtids) PrimType
int32
      Count Elements Exp
num_elements <- Exp -> Count Elements Exp
Imp.elements (Exp -> Count Elements Exp)
-> ImpM ExplicitMemory KernelEnv KernelOp Exp
-> ImpM ExplicitMemory KernelEnv KernelOp (Count Elements Exp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
w

      [SegRedOpSlug]
slugs <- ((SegRedOp ExplicitMemory, [VName], [VName])
 -> ImpM ExplicitMemory KernelEnv KernelOp SegRedOpSlug)
-> [(SegRedOp ExplicitMemory, [VName], [VName])]
-> ImpM ExplicitMemory KernelEnv KernelOp [SegRedOpSlug]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Exp
-> Exp
-> (SegRedOp ExplicitMemory, [VName], [VName])
-> ImpM ExplicitMemory KernelEnv KernelOp SegRedOpSlug
segRedOpSlug Exp
local_tid Exp
group_id) ([(SegRedOp ExplicitMemory, [VName], [VName])]
 -> ImpM ExplicitMemory KernelEnv KernelOp [SegRedOpSlug])
-> [(SegRedOp ExplicitMemory, [VName], [VName])]
-> ImpM ExplicitMemory KernelEnv KernelOp [SegRedOpSlug]
forall a b. (a -> b) -> a -> b
$
               [SegRedOp ExplicitMemory]
-> [[VName]]
-> [[VName]]
-> [(SegRedOp ExplicitMemory, [VName], [VName])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegRedOp ExplicitMemory]
reds [[VName]]
reds_arrs [[VName]]
reds_group_res_arrs
      [Lambda ExplicitMemory]
reds_op_renamed <-
        KernelConstants
-> [(VName, Exp)]
-> Count Elements Exp
-> Exp
-> Count Elements Exp
-> VName
-> [SegRedOpSlug]
-> DoSegBody
-> InKernelGen [Lambda ExplicitMemory]
reductionStageOne KernelConstants
constants ([VName] -> [Exp] -> [(VName, Exp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [Exp]
dims') Count Elements Exp
num_elements
        Exp
global_tid Count Elements Exp
elems_per_thread VName
threads_per_segment
        [SegRedOpSlug]
slugs DoSegBody
body

      let segred_pes :: [[PatElemT (MemInfo SubExp NoUniqueness MemBind)]]
segred_pes = [Int]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [[PatElemT (MemInfo SubExp NoUniqueness MemBind)]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegRedOp ExplicitMemory -> Int)
-> [SegRedOp ExplicitMemory] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegRedOp ExplicitMemory -> [SubExp])
-> SegRedOp ExplicitMemory
-> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegRedOp ExplicitMemory -> [SubExp]
forall lore. SegRedOp lore -> [SubExp]
segRedNeutral) [SegRedOp ExplicitMemory]
reds) ([PatElemT (MemInfo SubExp NoUniqueness MemBind)]
 -> [[PatElemT (MemInfo SubExp NoUniqueness MemBind)]])
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [[PatElemT (MemInfo SubExp NoUniqueness MemBind)]]
forall a b. (a -> b) -> a -> b
$
                       PatternT (MemInfo SubExp NoUniqueness MemBind)
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
forall attr. PatternT attr -> [PatElemT attr]
patternElements Pattern ExplicitMemory
PatternT (MemInfo SubExp NoUniqueness MemBind)
segred_pat

          multiple_groups_per_segment :: InKernelGen ()
multiple_groups_per_segment =
            [(SegRedOp ExplicitMemory, [VName], [VName],
  [PatElemT (MemInfo SubExp NoUniqueness MemBind)], SegRedOpSlug,
  Lambda ExplicitMemory, Int32)]
-> ((SegRedOp ExplicitMemory, [VName], [VName],
     [PatElemT (MemInfo SubExp NoUniqueness MemBind)], SegRedOpSlug,
     Lambda ExplicitMemory, Int32)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegRedOp ExplicitMemory]
-> [[VName]]
-> [[VName]]
-> [[PatElemT (MemInfo SubExp NoUniqueness MemBind)]]
-> [SegRedOpSlug]
-> [Lambda ExplicitMemory]
-> [Int32]
-> [(SegRedOp ExplicitMemory, [VName], [VName],
     [PatElemT (MemInfo SubExp NoUniqueness MemBind)], SegRedOpSlug,
     Lambda ExplicitMemory, Int32)]
forall a b c d e f g.
[a]
-> [b]
-> [c]
-> [d]
-> [e]
-> [f]
-> [g]
-> [(a, b, c, d, e, f, g)]
zip7 [SegRedOp ExplicitMemory]
reds [[VName]]
reds_arrs [[VName]]
reds_group_res_arrs [[PatElemT (MemInfo SubExp NoUniqueness MemBind)]]
segred_pes
                   [SegRedOpSlug]
slugs [Lambda ExplicitMemory]
reds_op_renamed [Int32
0..]) (((SegRedOp ExplicitMemory, [VName], [VName],
   [PatElemT (MemInfo SubExp NoUniqueness MemBind)], SegRedOpSlug,
   Lambda ExplicitMemory, Int32)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((SegRedOp ExplicitMemory, [VName], [VName],
     [PatElemT (MemInfo SubExp NoUniqueness MemBind)], SegRedOpSlug,
     Lambda ExplicitMemory, Int32)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            \(SegRedOp Commutativity
_ Lambda ExplicitMemory
red_op [SubExp]
nes Shape
_, [VName]
red_arrs, [VName]
group_res_arrs, [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes,
              SegRedOpSlug
slug, Lambda ExplicitMemory
red_op_renamed, Int32
i) -> do
              let red_acc_params :: [Param (MemInfo SubExp NoUniqueness MemBind)]
red_acc_params = Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
take ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([Param (MemInfo SubExp NoUniqueness MemBind)]
 -> [Param (MemInfo SubExp NoUniqueness MemBind)])
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda ExplicitMemory
red_op
              KernelConstants
-> [PatElem ExplicitMemory]
-> Exp
-> Exp
-> [Exp]
-> Exp
-> Exp
-> SegRedOpSlug
-> [LParam ExplicitMemory]
-> Lambda ExplicitMemory
-> [SubExp]
-> Exp
-> VName
-> Exp
-> VName
-> [VName]
-> [VName]
-> InKernelGen ()
reductionStageTwo KernelConstants
constants [PatElem ExplicitMemory]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes
                Exp
group_id Exp
flat_segment_id ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
segment_gtids)
                Exp
first_group_for_segment Exp
groups_per_segment
                SegRedOpSlug
slug [LParam ExplicitMemory]
[Param (MemInfo SubExp NoUniqueness MemBind)]
red_acc_params Lambda ExplicitMemory
red_op_renamed [SubExp]
nes
                (Int -> Exp
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
num_counters) VName
counter (PrimValue -> Exp
forall v. PrimValue -> PrimExp v
ValueExp (PrimValue -> Exp) -> PrimValue -> Exp
forall a b. (a -> b) -> a -> b
$ IntValue -> PrimValue
IntValue (IntValue -> PrimValue) -> IntValue -> PrimValue
forall a b. (a -> b) -> a -> b
$ Int32 -> IntValue
Int32Value Int32
i)
                VName
sync_arr [VName]
group_res_arrs [VName]
red_arrs

          one_group_per_segment :: InKernelGen ()
one_group_per_segment =
            String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
comment String
"first thread in group saves final result to memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            [(SegRedOpSlug, [PatElemT (MemInfo SubExp NoUniqueness MemBind)])]
-> ((SegRedOpSlug,
     [PatElemT (MemInfo SubExp NoUniqueness MemBind)])
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegRedOpSlug]
-> [[PatElemT (MemInfo SubExp NoUniqueness MemBind)]]
-> [(SegRedOpSlug,
     [PatElemT (MemInfo SubExp NoUniqueness MemBind)])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegRedOpSlug]
slugs [[PatElemT (MemInfo SubExp NoUniqueness MemBind)]]
segred_pes) (((SegRedOpSlug, [PatElemT (MemInfo SubExp NoUniqueness MemBind)])
  -> InKernelGen ())
 -> InKernelGen ())
-> ((SegRedOpSlug,
     [PatElemT (MemInfo SubExp NoUniqueness MemBind)])
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegRedOpSlug
slug, [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes) ->
            Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
local_tid Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              [(PatElemT (MemInfo SubExp NoUniqueness MemBind), (VName, [Exp]))]
-> ((PatElemT (MemInfo SubExp NoUniqueness MemBind),
     (VName, [Exp]))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [(VName, [Exp])]
-> [(PatElemT (MemInfo SubExp NoUniqueness MemBind),
     (VName, [Exp]))]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes (SegRedOpSlug -> [(VName, [Exp])]
slugAccs SegRedOpSlug
slug)) (((PatElemT (MemInfo SubExp NoUniqueness MemBind), (VName, [Exp]))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((PatElemT (MemInfo SubExp NoUniqueness MemBind),
     (VName, [Exp]))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT (MemInfo SubExp NoUniqueness MemBind)
v, (VName
acc, [Exp]
acc_is)) ->
              VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (MemInfo SubExp NoUniqueness MemBind)
v) ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
segment_gtids) (VName -> SubExp
Var VName
acc) [Exp]
acc_is

      Exp -> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall lore r op.
Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (Exp
groups_per_segment Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
1) InKernelGen ()
one_group_per_segment InKernelGen ()
multiple_groups_per_segment

-- | Compute how many groups cooperate on each segment, and how many
-- elements each thread must then process.
--
-- The arguments are, in order: the segment size, the number of
-- segments, the hinted total number of groups, and the group size.
--
-- Careful to avoid division by zero here: the number of segments is
-- clamped to at least one before dividing.  Because the division
-- rounds up, we have at least one group per segment (presumably
-- relying on the group-count hint being positive).
groupsPerSegmentAndElementsPerThread :: Imp.Exp -> Imp.Exp
                                     -> Count NumGroups Imp.Exp -> Count GroupSize Imp.Exp
                                     -> (Imp.Exp, Imp.Count Imp.Elements Imp.Exp)
groupsPerSegmentAndElementsPerThread segment_size num_segments num_groups_hint group_size =
  let groups_per_segment =
        unCount num_groups_hint `quotRoundingUp` BinOpExp (SMax Int32) 1 num_segments
      -- Each segment is covered by @group_size * groups_per_segment@
      -- threads; divide the segment evenly (rounding up) among them.
      elements_per_thread =
        segment_size `quotRoundingUp` (unCount group_size * groups_per_segment)
  in (groups_per_segment, Imp.elements elements_per_thread)

-- | A SegRedOp with auxiliary information.
data SegRedOpSlug =
  SegRedOpSlug
  { slugOp :: SegRedOp ExplicitMemory
    -- ^ The reduction operator this slug wraps.
  , slugArrs :: [VName]
    -- ^ The arrays used for computing the intra-group reduction
    -- (either local or global memory).
  , slugAccs :: [(VName, [Imp.Exp])]
    -- ^ Places to store accumulator in stage 1 reduction.  Each
    -- entry is a variable name together with the leading indices to
    -- use when reading or writing it.
  }

-- | The body of the reduction operator's lambda.
slugBody :: SegRedOpSlug -> Body ExplicitMemory
slugBody = lambdaBody . segRedLambda . slugOp

-- | All parameters of the reduction operator's lambda.
slugParams :: SegRedOpSlug -> [LParam ExplicitMemory]
slugParams = lambdaParams . segRedLambda . slugOp

-- | The neutral elements of the reduction operator.
slugNeutral :: SegRedOpSlug -> [SubExp]
slugNeutral = segRedNeutral . slugOp

-- | The shape associated with the reduction operator (as given by
-- 'segRedShape').
slugShape :: SegRedOpSlug -> Shape
slugShape = segRedShape . slugOp

-- | The combined commutativity of several reduction operators,
-- obtained by monoidally combining the commutativity of each one.
slugsComm :: [SegRedOpSlug] -> Commutativity
slugsComm = foldMap (segRedComm . slugOp)

-- | The operator lambda's parameters, split at the number of neutral
-- elements: 'accParams' is the leading part (one parameter per
-- neutral element) and 'nextParams' is whatever follows.
accParams, nextParams :: SegRedOpSlug -> [LParam ExplicitMemory]
accParams slug = take (length (slugNeutral slug)) $ slugParams slug
nextParams slug = drop (length (slugNeutral slug)) $ slugParams slug

-- | Construct a 'SegRedOpSlug' for a reduction operator.
--
-- The first two arguments are the local thread ID and the group ID.
-- The triple holds the operator, the arrays stored as 'slugArrs',
-- and the arrays that may back the accumulators.
--
-- The operator's lambda parameters are paired with the backing arrays
-- to produce the 'slugAccs' entries.
segRedOpSlug :: Imp.Exp
             -> Imp.Exp
             -> (SegRedOp ExplicitMemory, [VName], [VName])
             -> InKernelGen SegRedOpSlug
segRedOpSlug local_tid group_id (op, group_res_arrs, param_arrs) =
  SegRedOpSlug op group_res_arrs <$>
  zipWithM mkAcc (lambdaParams (segRedLambda op)) param_arrs
  where mkAcc p param_arr
          -- A primitive parameter of an operator with rank-0 shape
          -- gets a freshly declared scalar as its accumulator, which
          -- needs no indices.
          | Prim t <- paramType p,
            shapeRank (segRedShape op) == 0 = do
              acc <- dPrim (baseString (paramName p) <> "_acc") t
              return (acc, [])
          -- Otherwise accumulate directly in the provided array, at
          -- the position given by the local thread ID and group ID.
          | otherwise =
              return (param_arr, [local_tid, group_id])

reductionStageZero :: KernelConstants
                   -> [(VName, Imp.Exp)]
                   -> Imp.Count Imp.Elements Imp.Exp
                   -> Imp.Exp
                   -> Imp.Count Imp.Elements Imp.Exp
                   -> VName
                   -> [SegRedOpSlug]
                   -> DoSegBody
                   -> InKernelGen ([Lambda ExplicitMemory], InKernelGen ())
reductionStageZero :: KernelConstants
-> [(VName, Exp)]
-> Count Elements Exp
-> Exp
-> Count Elements Exp
-> VName
-> [SegRedOpSlug]
-> DoSegBody
-> InKernelGen ([Lambda ExplicitMemory], InKernelGen ())
reductionStageZero KernelConstants
constants [(VName, Exp)]
ispace Count Elements Exp
num_elements Exp
global_tid Count Elements Exp
elems_per_thread VName
threads_per_segment [SegRedOpSlug]
slugs DoSegBody
body = do
  let ([VName]
gtids, [Exp]
_dims) = [(VName, Exp)] -> ([VName], [Exp])
forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, Exp)]
ispace
      gtid :: VName
gtid = [VName] -> VName
forall a. [a] -> a
last [VName]
gtids
      local_tid :: Exp
local_tid = KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants

  -- Figure out how many elements this thread should process.
  VName
chunk_size <- String -> PrimType -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall lore r op. String -> PrimType -> ImpM lore r op VName
dPrim String
"chunk_size" PrimType
int32
  let ordering :: SplitOrdering
ordering = case [SegRedOpSlug] -> Commutativity
slugsComm [SegRedOpSlug]
slugs of
                   Commutativity
Commutative -> SubExp -> SplitOrdering
SplitStrided (SubExp -> SplitOrdering) -> SubExp -> SplitOrdering
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
threads_per_segment
                   Commutativity
Noncommutative -> SplitOrdering
SplitContiguous
  SplitOrdering
-> Exp
-> Count Elements Exp
-> Count Elements Exp
-> VName
-> InKernelGen ()
forall lore r op.
SplitOrdering
-> Exp
-> Count Elements Exp
-> Count Elements Exp
-> VName
-> ImpM lore r op ()
computeThreadChunkSize SplitOrdering
ordering Exp
global_tid Count Elements Exp
elems_per_thread Count Elements Exp
num_elements VName
chunk_size

  Maybe (Exp ExplicitMemory)
-> Scope ExplicitMemory -> InKernelGen ()
forall lore r op.
Maybe (Exp lore) -> Scope ExplicitMemory -> ImpM lore r op ()
dScope Maybe (Exp ExplicitMemory)
forall a. Maybe a
Nothing (Scope ExplicitMemory -> InKernelGen ())
-> Scope ExplicitMemory -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param (MemInfo SubExp NoUniqueness MemBind)]
-> Scope ExplicitMemory
forall lore attr.
(LParamAttr lore ~ attr) =>
[Param attr] -> Scope lore
scopeOfLParams ([Param (MemInfo SubExp NoUniqueness MemBind)]
 -> Scope ExplicitMemory)
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> Scope ExplicitMemory
forall a b. (a -> b) -> a -> b
$ (SegRedOpSlug -> [Param (MemInfo SubExp NoUniqueness MemBind)])
-> [SegRedOpSlug] -> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap SegRedOpSlug -> [LParam ExplicitMemory]
SegRedOpSlug -> [Param (MemInfo SubExp NoUniqueness MemBind)]
slugParams [SegRedOpSlug]
slugs

  String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"neutral-initialise the accumulators" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    [SegRedOpSlug]
-> (SegRedOpSlug -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [SegRedOpSlug]
slugs ((SegRedOpSlug -> InKernelGen ()) -> InKernelGen ())
-> (SegRedOpSlug -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegRedOpSlug
slug ->
    [((VName, [Exp]), SubExp)]
-> (((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [Exp])] -> [SubExp] -> [((VName, [Exp]), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegRedOpSlug -> [(VName, [Exp])]
slugAccs SegRedOpSlug
slug) (SegRedOpSlug -> [SubExp]
slugNeutral SegRedOpSlug
slug)) ((((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ())
-> (((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [Exp]
acc_is), SubExp
ne) ->
    Shape -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([Exp] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegRedOpSlug -> Shape
slugShape SegRedOpSlug
slug) (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
vec_is ->
    VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
acc ([Exp]
acc_is[Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++[Exp]
vec_is) SubExp
ne []

  [Lambda ExplicitMemory]
slugs_op_renamed <- (SegRedOpSlug
 -> ImpM ExplicitMemory KernelEnv KernelOp (Lambda ExplicitMemory))
-> [SegRedOpSlug] -> InKernelGen [Lambda ExplicitMemory]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Lambda ExplicitMemory
-> ImpM ExplicitMemory KernelEnv KernelOp (Lambda ExplicitMemory)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda (Lambda ExplicitMemory
 -> ImpM ExplicitMemory KernelEnv KernelOp (Lambda ExplicitMemory))
-> (SegRedOpSlug -> Lambda ExplicitMemory)
-> SegRedOpSlug
-> ImpM ExplicitMemory KernelEnv KernelOp (Lambda ExplicitMemory)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegRedOp ExplicitMemory -> Lambda ExplicitMemory
forall lore. SegRedOp lore -> Lambda lore
segRedLambda (SegRedOp ExplicitMemory -> Lambda ExplicitMemory)
-> (SegRedOpSlug -> SegRedOp ExplicitMemory)
-> SegRedOpSlug
-> Lambda ExplicitMemory
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegRedOpSlug -> SegRedOp ExplicitMemory
slugOp) [SegRedOpSlug]
slugs

  let doTheReduction :: InKernelGen ()
doTheReduction =
        [(Lambda ExplicitMemory, SegRedOpSlug)]
-> ((Lambda ExplicitMemory, SegRedOpSlug) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Lambda ExplicitMemory]
-> [SegRedOpSlug] -> [(Lambda ExplicitMemory, SegRedOpSlug)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Lambda ExplicitMemory]
slugs_op_renamed [SegRedOpSlug]
slugs) (((Lambda ExplicitMemory, SegRedOpSlug) -> InKernelGen ())
 -> InKernelGen ())
-> ((Lambda ExplicitMemory, SegRedOpSlug) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Lambda ExplicitMemory
slug_op_renamed, SegRedOpSlug
slug) ->
        Shape -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([Exp] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegRedOpSlug -> Shape
slugShape SegRedOpSlug
slug) (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
vec_is -> do
          String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
comment String
"to reduce current chunk, first store our result in memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
            [(Param (MemInfo SubExp NoUniqueness MemBind), (VName, [Exp]))]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), (VName, [Exp]))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [(VName, [Exp])]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), (VName, [Exp]))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegRedOpSlug -> [LParam ExplicitMemory]
slugParams SegRedOpSlug
slug) (SegRedOpSlug -> [(VName, [Exp])]
slugAccs SegRedOpSlug
slug)) (((Param (MemInfo SubExp NoUniqueness MemBind), (VName, [Exp]))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), (VName, [Exp]))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, (VName
acc, [Exp]
acc_is)) ->
              VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var VName
acc) ([Exp]
acc_is[Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++[Exp]
vec_is)

            [(VName, Param (MemInfo SubExp NoUniqueness MemBind))]
-> ((VName, Param (MemInfo SubExp NoUniqueness MemBind))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [(VName, Param (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegRedOpSlug -> [VName]
slugArrs SegRedOpSlug
slug) (SegRedOpSlug -> [LParam ExplicitMemory]
slugParams SegRedOpSlug
slug)) (((VName, Param (MemInfo SubExp NoUniqueness MemBind))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((VName, Param (MemInfo SubExp NoUniqueness MemBind))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, Param (MemInfo SubExp NoUniqueness MemBind)
p) ->
              Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind)
-> TypeBase Shape NoUniqueness
forall attr.
Typed attr =>
Param attr -> TypeBase Shape NoUniqueness
paramType Param (MemInfo SubExp NoUniqueness MemBind)
p) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
arr [Exp
local_tid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) []

          KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal -- Also implicitly barrier.

          Exp -> Lambda ExplicitMemory -> [VName] -> InKernelGen ()
groupReduce (KernelConstants -> Exp
kernelGroupSize KernelConstants
constants) Lambda ExplicitMemory
slug_op_renamed (SegRedOpSlug -> [VName]
slugArrs SegRedOpSlug
slug)

          KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal

          String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"first thread saves the result in accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
local_tid Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            [((VName, [Exp]), Param (MemInfo SubExp NoUniqueness MemBind))]
-> (((VName, [Exp]), Param (MemInfo SubExp NoUniqueness MemBind))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [Exp])]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [((VName, [Exp]), Param (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegRedOpSlug -> [(VName, [Exp])]
slugAccs SegRedOpSlug
slug) (Lambda ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda ExplicitMemory
slug_op_renamed)) ((((VName, [Exp]), Param (MemInfo SubExp NoUniqueness MemBind))
  -> InKernelGen ())
 -> InKernelGen ())
-> (((VName, [Exp]), Param (MemInfo SubExp NoUniqueness MemBind))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [Exp]
acc_is), Param (MemInfo SubExp NoUniqueness MemBind)
p) ->
            VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
acc ([Exp]
acc_is[Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++[Exp]
vec_is) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) []

  -- If this is a non-commutative reduction, each thread must run the
  -- loop the same number of iterations, because we will be performing
  -- a group-wide reduction in there.
  let comm :: Commutativity
comm = [SegRedOpSlug] -> Commutativity
slugsComm [SegRedOpSlug]
slugs
      (Exp
bound, InKernelGen () -> InKernelGen ()
check_bounds) =
        case Commutativity
comm of
          Commutativity
Commutative -> (VName -> PrimType -> Exp
Imp.var VName
chunk_size PrimType
int32, InKernelGen () -> InKernelGen ()
forall a. a -> a
id)
          Commutativity
Noncommutative -> (Count Elements Exp -> Exp
forall u e. Count u e -> e
Imp.unCount Count Elements Exp
elems_per_thread,
                             Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (VName -> PrimType -> Exp
Imp.var VName
gtid PrimType
int32 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Count Elements Exp -> Exp
forall u e. Count u e -> e
Imp.unCount Count Elements Exp
num_elements))

  String -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
String -> Exp -> (Exp -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" Exp
bound ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
i -> do
    VName
gtid VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<--
      case Commutativity
comm of
        Commutativity
Commutative ->
          Exp
global_tid Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+
          VName -> PrimType -> Exp
Imp.var VName
threads_per_segment PrimType
int32 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
i
        Commutativity
Noncommutative ->
          let index_in_segment :: Exp
index_in_segment = Exp
global_tid Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`quot` KernelConstants -> Exp
kernelGroupSize KernelConstants
constants
          in Exp
local_tid Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+
             (Exp
index_in_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count Elements Exp -> Exp
forall u e. Count u e -> e
Imp.unCount Count Elements Exp
elems_per_thread Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
i) Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
*
             KernelConstants -> Exp
kernelGroupSize KernelConstants
constants

    InKernelGen () -> InKernelGen ()
check_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"apply map function" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      DoSegBody
body DoSegBody -> DoSegBody
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [Exp])]
all_red_res -> do

      let slugs_res :: [[(SubExp, [Exp])]]
slugs_res = [Int] -> [(SubExp, [Exp])] -> [[(SubExp, [Exp])]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegRedOpSlug -> Int) -> [SegRedOpSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegRedOpSlug -> [SubExp]) -> SegRedOpSlug -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegRedOpSlug -> [SubExp]
slugNeutral) [SegRedOpSlug]
slugs) [(SubExp, [Exp])]
all_red_res

      [(SegRedOpSlug, [(SubExp, [Exp])])]
-> ((SegRedOpSlug, [(SubExp, [Exp])]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegRedOpSlug]
-> [[(SubExp, [Exp])]] -> [(SegRedOpSlug, [(SubExp, [Exp])])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegRedOpSlug]
slugs [[(SubExp, [Exp])]]
slugs_res) (((SegRedOpSlug, [(SubExp, [Exp])]) -> InKernelGen ())
 -> InKernelGen ())
-> ((SegRedOpSlug, [(SubExp, [Exp])]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegRedOpSlug
slug, [(SubExp, [Exp])]
red_res) ->
        Shape -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([Exp] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegRedOpSlug -> Shape
slugShape SegRedOpSlug
slug) (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
vec_is -> do
        String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"load accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          [(Param (MemInfo SubExp NoUniqueness MemBind), (VName, [Exp]))]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), (VName, [Exp]))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [(VName, [Exp])]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), (VName, [Exp]))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegRedOpSlug -> [LParam ExplicitMemory]
accParams SegRedOpSlug
slug) (SegRedOpSlug -> [(VName, [Exp])]
slugAccs SegRedOpSlug
slug)) (((Param (MemInfo SubExp NoUniqueness MemBind), (VName, [Exp]))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), (VName, [Exp]))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, (VName
acc, [Exp]
acc_is)) ->
          VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var VName
acc) ([Exp]
acc_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp]
vec_is)
        String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"load new values" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          [(Param (MemInfo SubExp NoUniqueness MemBind), (SubExp, [Exp]))]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), (SubExp, [Exp]))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [(SubExp, [Exp])]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), (SubExp, [Exp]))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegRedOpSlug -> [LParam ExplicitMemory]
nextParams SegRedOpSlug
slug) [(SubExp, [Exp])]
red_res) (((Param (MemInfo SubExp NoUniqueness MemBind), (SubExp, [Exp]))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), (SubExp, [Exp]))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, (SubExp
res, [Exp]
res_is)) ->
          VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] SubExp
res ([Exp]
res_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp]
vec_is)
        String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"apply reduction operator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          Names -> Stms ExplicitMemory -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body ExplicitMemory -> Stms ExplicitMemory
forall lore. BodyT lore -> Stms lore
bodyStms (Body ExplicitMemory -> Stms ExplicitMemory)
-> Body ExplicitMemory -> Stms ExplicitMemory
forall a b. (a -> b) -> a -> b
$ SegRedOpSlug -> Body ExplicitMemory
slugBody SegRedOpSlug
slug) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"store in accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          [((VName, [Exp]), SubExp)]
-> (((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [Exp])] -> [SubExp] -> [((VName, [Exp]), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip
                  (SegRedOpSlug -> [(VName, [Exp])]
slugAccs SegRedOpSlug
slug)
                  (Body ExplicitMemory -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Body ExplicitMemory -> [SubExp])
-> Body ExplicitMemory -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegRedOpSlug -> Body ExplicitMemory
slugBody SegRedOpSlug
slug)) ((((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ())
-> (((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [Exp]
acc_is), SubExp
se) ->
          VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
acc ([Exp]
acc_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp]
vec_is) SubExp
se []

    case Commutativity
comm of
      Commutativity
Noncommutative -> do
        InKernelGen ()
doTheReduction
        String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"first thread keeps accumulator; others reset to neutral element" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
          let reset_to_neutral :: InKernelGen ()
reset_to_neutral =
                [SegRedOpSlug]
-> (SegRedOpSlug -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [SegRedOpSlug]
slugs ((SegRedOpSlug -> InKernelGen ()) -> InKernelGen ())
-> (SegRedOpSlug -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegRedOpSlug
slug ->
                [((VName, [Exp]), SubExp)]
-> (((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, [Exp])] -> [SubExp] -> [((VName, [Exp]), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegRedOpSlug -> [(VName, [Exp])]
slugAccs SegRedOpSlug
slug) (SegRedOpSlug -> [SubExp]
slugNeutral SegRedOpSlug
slug)) ((((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ())
-> (((VName, [Exp]), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \((VName
acc, [Exp]
acc_is), SubExp
ne) ->
                Shape -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([Exp] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (SegRedOpSlug -> Shape
slugShape SegRedOpSlug
slug) (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
vec_is ->
                VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
acc ([Exp]
acc_is[Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++[Exp]
vec_is) SubExp
ne []
          Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sUnless (Exp
local_tid Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0) InKernelGen ()
reset_to_neutral
      Commutativity
_ -> () -> InKernelGen ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()

  ([Lambda ExplicitMemory], InKernelGen ())
-> InKernelGen ([Lambda ExplicitMemory], InKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return ([Lambda ExplicitMemory]
slugs_op_renamed, InKernelGen ()
doTheReduction)

-- | First stage of the reduction.
--
-- Delegates the per-thread accumulation to 'reductionStageZero',
-- which yields the renamed reduction operators together with an
-- action ('doTheReduction') that performs the group-wide reduction.
--
-- * For a non-commutative reduction, stage zero already runs the
--   group-wide reduction inside its loop, so here we only copy every
--   accumulator (at its accumulator indices) into the corresponding
--   accumulator parameter of the operator.
--
-- * Otherwise the group-wide reduction is run exactly once, here.
--
-- Returns the renamed reduction operators produced by stage zero.
reductionStageOne :: KernelConstants
                  -> [(VName, Imp.Exp)]
                  -> Imp.Count Imp.Elements Imp.Exp
                  -> Imp.Exp
                  -> Imp.Count Imp.Elements Imp.Exp
                  -> VName
                  -> [SegRedOpSlug]
                  -> DoSegBody
                  -> InKernelGen [Lambda ExplicitMemory]
reductionStageOne constants ispace num_elements global_tid elems_per_thread threads_per_segment slugs body = do
  (slugs_op_renamed, doTheReduction) <-
    reductionStageZero constants ispace num_elements global_tid elems_per_thread threads_per_segment slugs body

  case slugsComm slugs of
    Noncommutative ->
      -- The reduction itself was already performed by stage zero;
      -- just expose the accumulators through the operator parameters.
      forM_ slugs $ \slug ->
        forM_ (zip (accParams slug) (slugAccs slug)) $ \(p, (acc, acc_is)) ->
          copyDWIMFix (paramName p) [] (Var acc) acc_is
    _ -> doTheReduction

  return slugs_op_renamed

-- | Second stage of the reduction: combine the per-group partial
-- results of one segment into the final result.
--
-- The generated code does the following:
--
-- 1. The thread with local id 0 copies the group's accumulators
--    (@slugAccs slug@) into @group_res_arrs@ at index @[0, group_id]@,
--    issues a global memory fence, atomically increments the counter
--    selected by @counter_i@, @num_counters@ and @flat_segment_id@,
--    and writes into @sync_arr[0]@ whether the previous counter value
--    was @groups_per_segment - 1@.
--
-- 2. After a barrier, every thread reads that flag into
--    @is_last_group@.
--
-- 3. If the flag is set, the counter is reset (by atomically adding
--    @-groups_per_segment@) and, for every index of @slugShape slug@,
--    each thread whose local id is below @groups_per_segment@ loads
--    the result of group @first_group_for_segment + local_tid@ while
--    the remaining threads load the neutral element.  The values are
--    reduced with 'groupReduce' over @red_arrs@, and the thread with
--    local id 0 writes the final values to @segred_pes@ at
--    @segment_gtids ++ vec_is@.
reductionStageTwo :: KernelConstants
                  -> [PatElem ExplicitMemory]
                  -> Imp.Exp
                  -> Imp.Exp
                  -> [Imp.Exp]
                  -> Imp.Exp
                  -> Imp.Exp
                  -> SegRedOpSlug
                  -> [LParam ExplicitMemory]
                  -> Lambda ExplicitMemory -> [SubExp]
                  -> Imp.Exp -> VName -> Imp.Exp -> VName -> [VName] -> [VName]
                  -> InKernelGen ()
reductionStageTwo constants segred_pes
                  group_id flat_segment_id segment_gtids first_group_for_segment groups_per_segment
                  slug red_acc_params
                  red_op_renamed nes
                  num_counters counter counter_i sync_arr group_res_arrs red_arrs = do
  let local_tid = kernelLocalThreadId constants
      group_size = kernelGroupSize constants
  old_counter <- dPrim "old_counter" int32
  -- Segments share a pool of 'num_counters' counters; the one used is
  -- picked by the segment id modulo the pool size, offset by
  -- 'counter_i'.
  (counter_mem, _, counter_offset) <-
    fullyIndexArray counter [counter_i * num_counters +
                             flat_segment_id `rem` num_counters]
  comment "first thread in group saves group result to global memory" $
    sWhen (local_tid .==. 0) $ do
      forM_ (take (length nes) $ zip group_res_arrs (slugAccs slug)) $ \(v, (acc, acc_is)) ->
        copyDWIMFix v [0, group_id] (Var acc) acc_is
      sOp $ Imp.MemFence Imp.FenceGlobal
      -- Increment the counter, thus stating that our result is
      -- available.
      sOp $ Imp.Atomic DefaultSpace $
        Imp.AtomicAdd Int32 old_counter counter_mem counter_offset 1
      -- Now check if we were the last group to write our result.  If
      -- so, it is our responsibility to produce the final result.
      sWrite sync_arr [0] $ Imp.var old_counter int32 .==. groups_per_segment - 1

  sOp $ Imp.Barrier Imp.FenceGlobal

  is_last_group <- dPrim "is_last_group" Bool
  copyDWIMFix is_last_group [] (Var sync_arr) [0]
  sWhen (Imp.var is_last_group Bool) $ do
    -- The final group has written its result (and it was
    -- us!), so read in all the group results and perform the
    -- final stage of the reduction.  But first, we reset the
    -- counter so it is ready for next time.  This is done
    -- with an atomic to avoid warnings about write/write
    -- races in oclgrind.
    sWhen (local_tid .==. 0) $
      sOp $ Imp.Atomic DefaultSpace $
        Imp.AtomicAdd Int32 old_counter counter_mem counter_offset $
          negate groups_per_segment
    sLoopNest (slugShape slug) $ \vec_is -> do
      comment "read in the per-group-results" $
        forM_ (zip4 red_acc_params red_arrs nes group_res_arrs) $
          \(p, arr, ne, group_res_arr) -> do
            let load_group_result =
                  copyDWIMFix (paramName p) []
                  (Var group_res_arr) ([0, first_group_for_segment + local_tid] ++ vec_is)
                load_neutral_element =
                  copyDWIMFix (paramName p) [] ne []
            -- Threads beyond the number of groups in this segment
            -- contribute the neutral element.
            sIf (local_tid .<. groups_per_segment)
              load_group_result load_neutral_element
            -- Only primitive-typed parameters are copied into the
            -- reduction array (NOTE(review): presumably non-primitive
            -- parameters are already backed by it -- confirm).
            when (primType $ paramType p) $
              copyDWIMFix arr [local_tid] (Var $ paramName p) []

      sOp $ Imp.Barrier Imp.FenceLocal

      sComment "reduce the per-group results" $ do
        groupReduce group_size red_op_renamed red_arrs

        sComment "and back to memory with the final result" $
          sWhen (local_tid .==. 0) $
            forM_ (zip segred_pes $ lambdaParams red_op_renamed) $ \(pe, p) ->
              copyDWIMFix (patElemName pe) (segment_gtids++vec_is) (Var $ paramName p) []