{-# LANGUAGE TypeFamilies #-}

-- | Code generation for segmented and non-segmented scans.  Uses a
-- fast single-pass algorithm, which however only works on NVIDIA GPUs
-- and imposes some constraints on the operator.  We use this when we can.
module Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass (compileSegScan) where

import Control.Monad.Except
import Data.List (zip4)
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.Transform.Rename
import Futhark.Util (takeLast)
import Futhark.Util.IntegralExp (IntegralExp (mod, rem), divUp, quot)
import Prelude hiding (mod, quot, rem)

-- | Split the parameters of the scan operator's lambda into two
-- groups.  The split point is the number of neutral elements of the
-- operator: 'xParams' returns the first that many lambda parameters,
-- and 'yParams' returns all the remaining ones.
xParams, yParams :: SegBinOp GPUMem -> [LParam GPUMem]
xParams scan =
  take (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
yParams scan =
  drop (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))

-- | @alignTo x a@ computes @(x `divUp` a) * a@, i.e. @x@ rounded up to
-- a multiple of @a@ (given that 'divUp' is division rounding up).
alignTo :: IntegralExp a => a -> a -> a
alignTo x a = (x `divUp` a) * a

-- | Declare the arrays that the single-pass scan kernel keeps in the
-- @local@ memory space.
--
-- One block of memory (@local_mem@) is allocated and every array
-- declared here is placed inside it.  The block is sized as the
-- maximum of three requirements (lookback, prefix arrays, transposed
-- arrays) rather than their sum, because the memory is reused between
-- these uses.
--
-- Arguments: the group size, the factor @m@ that is multiplied by the
-- group size to obtain the length of the transposed arrays
-- (presumably the number of elements per thread -- confirm against
-- the caller), and the element types of the scanned values.
--
-- Returns @(sharedId, transposedArrays, prefixArrays, warpscan,
-- warpExchanges)@, with one array per element type in each of the
-- three lists.
createLocalArrays ::
  Count GroupSize SubExp ->
  SubExp ->
  [PrimType] ->
  InKernelGen (VName, [VName], [VName], VName, [VName])
createLocalArrays (Count groupSize) m types = do
  let groupSizeE = pe64 groupSize
      workSize = pe64 m * groupSizeE
      -- Bytes needed to lay out one group-sized array per type back
      -- to back, each aligned to its own element size.
      prefixArraysSize =
        foldl (\acc tySize -> alignTo acc tySize + tySize * groupSizeE) 0 $
          map primByteSize types
      -- Bytes needed by the largest single transposed array.  An empty
      -- list of types needs no space; matching on the list avoids the
      -- partial 'foldl1' while producing the same expression for
      -- non-empty input.
      maxTransposedArraySize =
        case map (\ty -> workSize * primByteSize ty) types of
          [] -> 0
          s : ss -> foldl sMax64 s ss

      warpSize :: Num a => a
      warpSize = 32
      -- Bytes needed to lay out one warp-sized array per type.
      maxWarpExchangeSize =
        foldl (\acc tySize -> alignTo acc tySize + tySize * fromInteger warpSize) 0 $
          map primByteSize types
      -- The extra 'warpSize' bytes match the starting offset used for
      -- 'warpByteOffsets' below.
      maxLookbackSize = maxWarpExchangeSize + warpSize
      size = Imp.bytes $ maxLookbackSize `sMax64` prefixArraysSize `sMax64` maxTransposedArraySize

      varTE :: TV Int64 -> TPrimExp Int64 VName
      varTE = le64 . tvVar

  -- Running byte offsets of the per-type prefix arrays, bound to
  -- scalar variables.
  byteOffsets <-
    mapM (fmap varTE . dPrimV "byte_offsets") $
      scanl (\off tySize -> alignTo off tySize + groupSizeE * tySize) 0 $
        map primByteSize types

  -- Running byte offsets of the per-type warp exchange arrays.  They
  -- start at 'warpSize' bytes; presumably this leaves room for the
  -- 'warpSize'-element int8 @warpscan@ array declared below (confirm
  -- that 'sArrayInMem' places arrays at the start of the block).
  warpByteOffsets <-
    mapM (fmap varTE . dPrimV "warp_byte_offset") $
      scanl (\off tySize -> alignTo off tySize + warpSize * tySize) warpSize $
        map primByteSize types

  sComment "Allocate reused shared memeory" $ pure ()

  localMem <- sAlloc "local_mem" size (Space "local")
  transposeArrayLength <- dPrimV "trans_arr_len" workSize

  sharedId <- sArrayInMem "shared_id" int32 (Shape [constant (1 :: Int32)]) localMem

  transposedArrays <-
    forM types $ \ty ->
      sArrayInMem
        "local_transpose_arr"
        ty
        (Shape [tvSize transposeArrayLength])
        localMem

  prefixArrays <-
    forM (zip byteOffsets types) $ \(off, ty) -> do
      -- Convert the byte offset into an offset counted in elements.
      let off' = off `quot` primByteSize ty
      sArray
        "local_prefix_arr"
        ty
        (Shape [groupSize])
        localMem
        $ IxFun.iotaOffset off' [pe64 groupSize]

  warpscan <- sArrayInMem "warpscan" int8 (Shape [constant (warpSize :: Int64)]) localMem
  warpExchanges <-
    forM (zip warpByteOffsets types) $ \(off, ty) -> do
      -- Convert the byte offset into an offset counted in elements.
      let off' = off `quot` primByteSize ty
      sArray
        "warp_exchange"
        ty
        (Shape [constant (warpSize :: Int64)])
        localMem
        $ IxFun.iotaOffset off' [warpSize]

  pure (sharedId, transposedArrays, prefixArrays, warpscan, warpExchanges)

-- | Generate code for a flag-aware scan within blocks of 32
-- consecutive threads, used for the lookback step.
--
-- Arguments: the kernel constants (for thread ids), an offset
-- @arrs_full_size@ that is added when indexing non-primitive
-- operands, the array of int8 status flags, the arrays holding the
-- values, and the scan operator.
--
-- The lambda parameters are split into a first half (x) and a second
-- half (y).  Each thread first reads its own flag and values into the
-- y parameters.  Then, with a stride that starts at 1 and doubles
-- while it is below 32, every thread whose index within its block is
-- at least the stride reads the flag and values that many positions
-- behind it into the x parameters, combines them, and writes the
-- result back.
--
-- Combining looks at the y flag: if it is 2 or 0 the y flag and values
-- simply replace the x ones; otherwise the operator body is compiled
-- with its results going to the x parameters.  (The values 0 and 2 are
-- named @statusX@ and @statusP@; NOTE(review): their exact meaning is
-- defined by the caller -- confirm.)
--
-- All generated code is wrapped in 'everythingVolatile'.
inBlockScanLookback ::
  KernelConstants ->
  Imp.TExp Int64 ->
  VName ->
  [VName] ->
  Lambda GPUMem ->
  InKernelGen ()
inBlockScanLookback constants arrs_full_size flag_arr arrs scan_lam = everythingVolatile $ do
  -- Both flags are declared with 'p_int8'; the phantom type of flg_x
  -- is pinned explicitly because nothing else constrains it.
  flg_x <- dPrim "flg_x" p_int8 :: InKernelGen (TV Int8)
  flg_y <- dPrim "flg_y" p_int8
  let flg_param_x = Param mempty (tvVar flg_x) (MemPrim p_int8)
      flg_param_y = Param mempty (tvVar flg_y) (MemPrim p_int8)
      flg_y_exp = tvExp flg_y
      statusP = (2 :: Imp.TExp Int8)
      statusX = (0 :: Imp.TExp Int8)

  dLParams (lambdaParams scan_lam)

  skip_threads <- dPrim "skip_threads" int32
  let in_block_thread_active =
        tvExp skip_threads .<=. in_block_id
      actual_params = lambdaParams scan_lam
      (x_params, y_params) =
        splitAt (length actual_params `div` 2) actual_params
      -- Copy the y values into the x parameters (primitive-typed
      -- parameters only).
      y_to_x =
        forM_ (zip x_params y_params) $ \(x, y) ->
          when (primType (paramType x)) $
            copyDWIM (paramName x) [] (Var (paramName y)) []
      -- Copy the y flag into the x flag.
      y_to_x_flg =
        copyDWIM (tvVar flg_x) [] (Var (tvVar flg_y)) []

  -- Set initial y values
  sComment "read input for in-block scan" $ do
    zipWithM_ readInitial (flg_param_y : y_params) (flag_arr : arrs)
    -- Since the final result is expected to be in x_params, we may
    -- need to copy it there for the first thread in the block.
    sWhen (in_block_id .==. 0) $ do
      y_to_x
      y_to_x_flg

  when array_scan barrier

  let op_to_x = do
        sIf
          (flg_y_exp .==. statusP .||. flg_y_exp .==. statusX)
          ( do
              y_to_x_flg
              y_to_x
          )
          (compileBody' x_params $ lambdaBody scan_lam)

  sComment "in-block scan (hopefully no barriers needed)" $ do
    skip_threads <-- 1

    sWhile (tvExp skip_threads .<. block_size) $ do
      sWhen in_block_thread_active $ do
        sComment "read operands" $
          zipWithM_
            (readParam (sExt64 $ tvExp skip_threads))
            (flg_param_x : x_params)
            (flag_arr : arrs)
        sComment "perform operation" op_to_x

        sComment "write result" $
          sequence_ $
            zipWith3
              writeResult
              (flg_param_x : x_params)
              (flg_param_y : y_params)
              (flag_arr : arrs)

      skip_threads <-- tvExp skip_threads * 2
  where
    p_int8 = IntType Int8
    -- Number of threads that scan together as one block.
    block_size = 32
    block_id = ltid32 `quot` block_size
    -- Index of this thread within its block.
    in_block_id = ltid32 - block_id * block_size
    ltid32 = kernelLocalThreadId constants
    ltid = sExt64 ltid32
    gtid = sExt64 $ kernelGlobalThreadId constants
    -- True when the operator returns at least one non-primitive value.
    array_scan = not $ all primType $ lambdaReturnType scan_lam
    -- Only invoked under @when array_scan@ above, so in practice the
    -- first alternative is the one that is emitted.
    barrier
      | array_scan =
          sOp $ Imp.Barrier Imp.FenceGlobal
      | otherwise =
          sOp $ Imp.Barrier Imp.FenceLocal

    -- Read this thread's own element: primitive parameters are indexed
    -- by the local thread id, all others by the global thread id.
    readInitial p arr
      | primType $ paramType p =
          copyDWIM (paramName p) [] (Var arr) [DimFix ltid]
      | otherwise =
          copyDWIM (paramName p) [] (Var arr) [DimFix gtid]

    -- Read the element @behind@ positions before this thread's own.
    readParam behind p arr
      | primType $ paramType p =
          copyDWIM (paramName p) [] (Var arr) [DimFix $ ltid - behind]
      | otherwise =
          copyDWIM (paramName p) [] (Var arr) [DimFix $ gtid - behind + arrs_full_size]

    -- Store x as this thread's new element and make it the next y.
    -- Only primitive values are written back to the array.
    writeResult x y arr
      | primType $ paramType x = do
          copyDWIM arr [DimFix ltid] (Var $ paramName x) []
          copyDWIM (paramName y) [] (Var $ paramName x) []
      | otherwise =
          copyDWIM (paramName y) [] (Var $ paramName x) []

-- | Compile 'SegScan' instance to host-level code with calls to a
-- single-pass kernel.
compileSegScan ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  SegBinOp GPUMem ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegScan :: Pat (MemInfo SubExp NoUniqueness MemBind)
-> SegLevel
-> SegSpace
-> SegBinOp GPUMem
-> KernelBody GPUMem
-> CallKernelGen ()
compileSegScan Pat (MemInfo SubExp NoUniqueness MemBind)
pat SegLevel
lvl SegSpace
space SegBinOp GPUMem
scanOp KernelBody GPUMem
kbody = do
  let Pat [PatElem (MemInfo SubExp NoUniqueness MemBind)]
all_pes = Pat (MemInfo SubExp NoUniqueness MemBind)
pat
      scanOpNe :: [SubExp]
scanOpNe = forall {k} (rep :: k). SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scanOp
      tys :: [PrimType]
tys = forall a b. (a -> b) -> [a] -> [b]
map (\(Prim PrimType
pt) -> PrimType
pt) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [Type]
lambdaReturnType forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scanOp
      n :: TExp Int64
n = forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space
      sumT :: Integer
      maxT :: Integer
      sumT :: Integer
sumT = forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\Integer
bytes PrimType
typ -> Integer
bytes forall a. Num a => a -> a -> a
+ forall a. Num a => PrimType -> a
primByteSize PrimType
typ) Integer
0 [PrimType]
tys
      primByteSize' :: PrimType -> Integer
primByteSize' = forall a. Ord a => a -> a -> a
max Integer
4 forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. Num a => PrimType -> a
primByteSize
      sumT' :: Integer
sumT' = forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\Integer
bytes PrimType
typ -> Integer
bytes forall a. Num a => a -> a -> a
+ PrimType -> Integer
primByteSize' PrimType
typ) Integer
0 [PrimType]
tys forall a. Integral a => a -> a -> a
`div` Integer
4
      maxT :: Integer
maxT = forall (t :: * -> *) a. (Foldable t, Ord a) => t a -> a
maximum (forall a b. (a -> b) -> [a] -> [b]
map forall a. Num a => PrimType -> a
primByteSize [PrimType]
tys)
      m :: Num a => a
      m :: forall a. Num a => a
m = forall a b. (Integral a, Num b) => a -> b
fromIntegral forall a b. (a -> b) -> a -> b
$ forall a. Ord a => a -> a -> a
max Integer
1 forall a b. (a -> b) -> a -> b
$ forall a. Ord a => a -> a -> a
min Integer
mem_constraint Integer
reg_constraint
      -- TODO: Make these constants dynamic by querying device
      k_reg :: Integer
k_reg = Integer
64
      k_mem :: Integer
k_mem = Integer
95
      mem_constraint :: Integer
mem_constraint = forall a. Ord a => a -> a -> a
max Integer
k_mem Integer
sumT forall a. Integral a => a -> a -> a
`div` Integer
maxT
      reg_constraint :: Integer
reg_constraint = (Integer
k_reg forall a. Num a => a -> a -> a
- Integer
1 forall a. Num a => a -> a -> a
- Integer
sumT') forall a. Integral a => a -> a -> a
`div` (Integer
2 forall a. Num a => a -> a -> a
* Integer
sumT')

      group_size :: Count GroupSize SubExp
group_size = SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl
      group_size' :: TExp Int64
group_size' = SubExp -> TExp Int64
pe64 forall a b. (a -> b) -> a -> b
$ forall {k} (u :: k) e. Count u e -> e
unCount Count GroupSize SubExp
group_size

  Count NumGroups SubExp
num_groups <-
    forall {k} (u :: k) e. e -> Count u e
Count forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (t :: k). TV t -> SubExp
tvSize forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"num_groups" (TExp Int64
n forall a. IntegralExp a => a -> a -> a
`divUp` (TExp Int64
group_size' forall a. Num a => a -> a -> a
* forall a. Num a => a
m))
  let num_groups' :: TExp Int64
num_groups' = SubExp -> TExp Int64
pe64 (forall {k} (u :: k) e. Count u e -> e
unCount Count NumGroups SubExp
num_groups)

  TExp Int64
num_threads <-
    forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"num_threads" forall a b. (a -> b) -> a -> b
$ TExp Int64
num_groups' forall a. Num a => a -> a -> a
* TExp Int64
group_size'

  let ([VName]
gtids, [SubExp]
dims) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      dims' :: [TExp Int64]
dims' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
dims
      segmented :: Bool
segmented = forall (t :: * -> *) a. Foldable t => t a -> Int
length [TExp Int64]
dims' forall a. Ord a => a -> a -> Bool
> Int
1
      not_segmented_e :: TPrimExp Bool VName
not_segmented_e = if Bool
segmented then forall v. TPrimExp Bool v
false else forall v. TPrimExp Bool v
true
      segment_size :: TExp Int64
segment_size = forall a. [a] -> a
last [TExp Int64]
dims'

      statusX, statusA, statusP :: Num a => a
      statusX :: forall a. Num a => a
statusX = a
0
      statusA :: forall a. Num a => a
statusA = a
1
      statusP :: forall a. Num a => a
statusP = a
2

  forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Sequential elements per thread (m):" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (forall a. Num a => a
m :: Imp.TExp Int32)
  forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Memory constraint " forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (forall a b. (Integral a, Num b) => a -> b
fromIntegral Integer
mem_constraint :: Imp.TExp Int32)
  forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Register constraint" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (forall a b. (Integral a, Num b) => a -> b
fromIntegral Integer
reg_constraint :: Imp.TExp Int32)
  forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"sumT'" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (forall a b. (Integral a, Num b) => a -> b
fromIntegral Integer
sumT' :: Imp.TExp Int32)

  VName
globalId <- forall {k} (rep :: k) r op.
String -> Space -> PrimType -> ArrayContents -> ImpM rep r op VName
sStaticArray String
"id_counter" (String -> Space
Space String
"device") PrimType
int32 forall a b. (a -> b) -> a -> b
$ Int -> ArrayContents
Imp.ArrayZeros Int
1
  VName
statusFlags <- forall {k} (rep :: k) r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"status_flags" PrimType
int8 (forall d. [d] -> ShapeBase d
Shape [forall {k} (u :: k) e. Count u e -> e
unCount Count NumGroups SubExp
num_groups]) (String -> Space
Space String
"device")
  ([VName]
aggregateArrays, [VName]
incprefixArrays) <-
    forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$
      forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [PrimType]
tys forall a b. (a -> b) -> a -> b
$ \PrimType
ty ->
        (,)
          forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"aggregates" PrimType
ty (forall d. [d] -> ShapeBase d
Shape [forall {k} (u :: k) e. Count u e -> e
unCount Count NumGroups SubExp
num_groups]) (String -> Space
Space String
"device")
          forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall {k} (rep :: k) r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"incprefixes" PrimType
ty (forall d. [d] -> ShapeBase d
Shape [forall {k} (u :: k) e. Count u e -> e
unCount Count NumGroups SubExp
num_groups]) (String -> Space
Space String
"device")

  VName -> SubExp -> CallKernelGen ()
sReplicate VName
statusFlags forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int8 forall a. Num a => a
statusX

  String
-> VName
-> KernelAttrs
-> ImpM GPUMem KernelEnv KernelOp ()
-> CallKernelGen ()
sKernelThread String
"segscan" (SegSpace -> VName
segFlat SegSpace
space) (Count NumGroups SubExp -> Count GroupSize SubExp -> KernelAttrs
defKernelAttrs Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size) forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. ImpM rep r op r
askEnv

    (VName
sharedId, [VName]
transposedArrays, [VName]
prefixArrays, VName
warpscan, [VName]
exchanges) <-
      Count GroupSize SubExp
-> SubExp
-> [PrimType]
-> InKernelGen (VName, [VName], [VName], VName, [VName])
createLocalArrays (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl) (IntType -> Integer -> SubExp
intConst IntType
Int64 forall a. Num a => a
m) [PrimType]
tys

    TV Int64
dynamicId <- forall {k1} {k2} (rep :: k1) r op (t :: k2).
String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"dynamic_id" PrimType
int32
    forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) forall a b. (a -> b) -> a -> b
$ do
      (VName
globalIdMem, Space
_, Count Elements (TExp Int64)
globalIdOff) <- forall {k} (rep :: k) r op.
VName
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray VName
globalId [TExp Int64
0]
      forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$
        Space -> AtomicOp -> KernelOp
Imp.Atomic Space
DefaultSpace forall a b. (a -> b) -> a -> b
$
          IntType
-> VName -> VName -> Count Elements (TExp Int64) -> Exp -> AtomicOp
Imp.AtomicAdd
            IntType
Int32
            (forall {k} (t :: k). TV t -> VName
tvVar TV Int64
dynamicId)
            VName
globalIdMem
            (forall {k} (u :: k) e. e -> Count u e
Count forall a b. (a -> b) -> a -> b
$ forall {k} (u :: k) e. Count u e -> e
unCount Count Elements (TExp Int64)
globalIdOff)
            (forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TExp Int32
1 :: Imp.TExp Int32))
      forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
sharedId [TExp Int64
0] (forall {k} (t :: k). TV t -> SubExp
tvSize TV Int64
dynamicId) []

    let localBarrier :: KernelOp
localBarrier = Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
        localFence :: KernelOp
localFence = Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceLocal
        globalFence :: KernelOp
globalFence = Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceGlobal

    forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
    forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Int64
dynamicId) [] (VName -> SubExp
Var VName
sharedId) [TExp Int64
0]
    forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier

    TV Int64
blockOff <-
      forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"blockOff" forall a b. (a -> b) -> a -> b
$
        forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dynamicId) forall a. Num a => a -> a -> a
* forall a. Num a => a
m forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants
    TExp Int64
sgmIdx <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"sgm_idx" forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
blockOff forall a. IntegralExp a => a -> a -> a
`mod` TExp Int64
segment_size
    TExp Int32
boundary <-
      forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"boundary" forall a b. (a -> b) -> a -> b
$
        forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$
          forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (forall a. Num a => a
m forall a. Num a => a -> a -> a
* TExp Int64
group_size') (TExp Int64
segment_size forall a. Num a => a -> a -> a
- TExp Int64
sgmIdx)
    TExp Int32
segsize_compact <-
      forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"segsize_compact" forall a b. (a -> b) -> a -> b
$
        forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$
          forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (forall a. Num a => a
m forall a. Num a => a -> a -> a
* TExp Int64
group_size') TExp Int64
segment_size
    [VName]
privateArrays <-
      forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [PrimType]
tys forall a b. (a -> b) -> a -> b
$ \PrimType
ty ->
        forall {k} (rep :: k) r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray
          String
"private"
          PrimType
ty
          (forall d. [d] -> ShapeBase d
Shape [IntType -> Integer -> SubExp
intConst IntType
Int64 forall a. Num a => a
m])
          ([SubExp] -> PrimType -> Space
ScalarSpace [IntType -> Integer -> SubExp
intConst IntType
Int64 forall a. Num a => a
m] PrimType
ty)

    forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Load and map" forall a b. (a -> b) -> a -> b
$
      forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" forall a. Num a => a
m forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
        -- The map's input index
        TExp Int64
phys_tid <-
          forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"phys_tid" forall a b. (a -> b) -> a -> b
$
            forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
blockOff
              forall a. Num a => a -> a -> a
+ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants)
              forall a. Num a => a -> a -> a
+ TExp Int64
i forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants
        forall {k} (rep :: k) r op.
[(VName, TExp Int64)] -> TExp Int64 -> ImpM rep r op ()
dIndexSpace (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [TExp Int64]
dims') TExp Int64
phys_tid
        -- Perform the map
        let in_bounds :: ImpM GPUMem KernelEnv KernelOp ()
in_bounds =
              forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) forall a b. (a -> b) -> a -> b
$ do
                let ([KernelResult]
all_scan_res, [KernelResult]
map_res) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall {k} (rep :: k). [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem
scanOp]) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody

                -- Write map results to their global memory destinations
                forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip (forall a. Int -> [a] -> [a]
takeLast (forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
map_res) [PatElem (MemInfo SubExp NoUniqueness MemBind)]
all_pes) [KernelResult]
map_res) forall a b. (a -> b) -> a -> b
$ \(PatElem (MemInfo SubExp NoUniqueness MemBind)
dest, KernelResult
src) ->
                  forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall dec. PatElem dec -> VName
patElemName PatElem (MemInfo SubExp NoUniqueness MemBind)
dest) (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids) (KernelResult -> SubExp
kernelResultSubExp KernelResult
src) []

                -- Write to-scan results to private memory.
                forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp [KernelResult]
all_scan_res) forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
src) ->
                  forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
src []

            out_of_bounds :: ImpM GPUMem KernelEnv KernelOp ()
out_of_bounds =
              forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays [SubExp]
scanOpNe) forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
ne) ->
                forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
ne []

        forall {k} (rep :: k) r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf (TExp Int64
phys_tid forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
n) ImpM GPUMem KernelEnv KernelOp ()
in_bounds ImpM GPUMem KernelEnv KernelOp ()
out_of_bounds

    forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Transpose scan inputs" forall a b. (a -> b) -> a -> b
$ do
      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
transposedArrays [VName]
privateArrays) forall a b. (a -> b) -> a -> b
$ \(VName
trans, VName
priv) -> do
        forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
        forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" forall a. Num a => a
m forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
          TExp Int64
sharedIdx <-
            forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"sharedIdx" forall a b. (a -> b) -> a -> b
$
              forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants)
                forall a. Num a => a -> a -> a
+ TExp Int64
i forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants
          forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
trans [TExp Int64
sharedIdx] (VName -> SubExp
Var VName
priv) [TExp Int64
i]
        forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
        forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" forall a. Num a => a
m forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
          TV Int32
sharedIdx <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"sharedIdx" forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants forall a. Num a => a -> a -> a
* forall a. Num a => a
m forall a. Num a => a -> a -> a
+ TExp Int32
i
          forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
priv [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
i] (VName -> SubExp
Var VName
trans) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
sharedIdx]
      forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier

    forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Per thread scan" forall a b. (a -> b) -> a -> b
$ do
      -- We don't need to touch the first element, so only m-1
      -- iterations here.
      TExp Int32
globalIdx <-
        forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"gidx" forall a b. (a -> b) -> a -> b
$
          (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants forall a. Num a => a -> a -> a
* forall a. Num a => a
m) forall a. Num a => a -> a -> a
+ TExp Int32
1
      forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" (forall a. Num a => a
m forall a. Num a => a -> a -> a
- TExp Int64
1) forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
        let xs :: [VName]
xs = forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scanOp
            ys :: [VName]
ys = forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> [LParam GPUMem]
yParams SegBinOp GPUMem
scanOp
        -- determine if start of segment
        TPrimExp Bool VName
new_sgm <-
          if Bool
segmented
            then forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"new_sgm" forall a b. (a -> b) -> a -> b
$ (TExp Int32
globalIdx forall a. Num a => a -> a -> a
+ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
i forall a. Num a => a -> a -> a
- TExp Int32
boundary) forall a. IntegralExp a => a -> a -> a
`mod` TExp Int32
segsize_compact forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0
            else forall (f :: * -> *) a. Applicative f => a -> f a
pure forall v. TPrimExp Bool v
false
        -- skip scan of first element in segment
        forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sUnless TPrimExp Bool VName
new_sgm forall a b. (a -> b) -> a -> b
$ do
          forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays forall a b. (a -> b) -> a -> b
$ forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
xs [VName]
ys [PrimType]
tys) forall a b. (a -> b) -> a -> b
$ \(VName
src, (VName
x, VName
y, PrimType
ty)) -> do
            forall {k} (rep :: k) r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
x PrimType
ty
            forall {k} (rep :: k) r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
y PrimType
ty
            forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
x [] (VName -> SubExp
Var VName
src) [TExp Int64
i]
            forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
y [] (VName -> SubExp
Var VName
src) [TExp Int64
i forall a. Num a => a -> a -> a
+ TExp Int64
1]

          forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scanOp) forall a b. (a -> b) -> a -> b
$
            forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scanOp) forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
res) ->
              forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i forall a. Num a => a -> a -> a
+ TExp Int64
1] SubExp
res []

    forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Publish results in shared memory" forall a b. (a -> b) -> a -> b
$ do
      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
prefixArrays [VName]
privateArrays) forall a b. (a -> b) -> a -> b
$ \(VName
dest, VName
src) ->
        forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants] (VName -> SubExp
Var VName
src) [forall a. Num a => a
m forall a. Num a => a -> a -> a
- TExp Int64
1]
      forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier

    let crossesSegment :: Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
crossesSegment = do
          forall (f :: * -> *). Alternative f => Bool -> f ()
guard Bool
segmented
          forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ \TExp Int32
from TExp Int32
to ->
            let from' :: TExp Int32
from' = (TExp Int32
from forall a. Num a => a -> a -> a
+ TExp Int32
1) forall a. Num a => a -> a -> a
* forall a. Num a => a
m forall a. Num a => a -> a -> a
- TExp Int32
1
                to' :: TExp Int32
to' = (TExp Int32
to forall a. Num a => a -> a -> a
+ TExp Int32
1) forall a. Num a => a -> a -> a
* forall a. Num a => a
m forall a. Num a => a -> a -> a
- TExp Int32
1
             in (TExp Int32
to' forall a. Num a => a -> a -> a
- TExp Int32
from') forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TExp Int32
to' forall a. Num a => a -> a -> a
+ TExp Int32
segsize_compact forall a. Num a => a -> a -> a
- TExp Int32
boundary) forall a. IntegralExp a => a -> a -> a
`mod` TExp Int32
segsize_compact

    Lambda GPUMem
scanOp' <- forall {k} (rep :: k) (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scanOp

    [TV Any]
accs <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall {k1} {k2} (rep :: k1) r op (t :: k2).
String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"acc") [PrimType]
tys
    forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Scan results (with warp scan)" forall a b. (a -> b) -> a -> b
$ do
      Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
-> TExp Int64
-> TExp Int64
-> Lambda GPUMem
-> [VName]
-> ImpM GPUMem KernelEnv KernelOp ()
groupScan
        Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
crossesSegment
        TExp Int64
num_threads
        (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
        Lambda GPUMem
scanOp'
        [VName]
prefixArrays

      forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier

      let firstThread :: TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
firstThread TV Any
acc VName
prefixes =
            forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
acc) [] (VName -> SubExp
Var VName
prefixes) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants) forall a. Num a => a -> a -> a
- TExp Int64
1]
          notFirstThread :: TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
notFirstThread TV Any
acc VName
prefixes =
            forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
acc) [] (VName -> SubExp
Var VName
prefixes) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants) forall a. Num a => a -> a -> a
- TExp Int64
1]
      forall {k} (rep :: k) r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
        (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0)
        (forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
firstThread [TV Any]
accs [VName]
prefixArrays)
        (forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ TV Any -> VName -> ImpM GPUMem KernelEnv KernelOp ()
notFirstThread [TV Any]
accs [VName]
prefixArrays)

      forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier

    [TV Any]
prefixes <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanOpNe [PrimType]
tys) forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, PrimType
ty) ->
      forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"prefix" forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp forall a b. (a -> b) -> a -> b
$ forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
    TPrimExp Bool VName
blockNewSgm <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"block_new_sgm" forall a b. (a -> b) -> a -> b
$ TExp Int64
sgmIdx forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0
    forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Perform lookback" forall a b. (a -> b) -> a -> b
$ do
      forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Bool VName
blockNewSgm forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) forall a b. (a -> b) -> a -> b
$ do
        forall {k} (rep :: k) r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
          forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
accs [VName]
incprefixArrays) forall a b. (a -> b) -> a -> b
$ \(TV Any
acc, VName
incprefixArray) ->
            forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dynamicId] (forall {k} (t :: k). TV t -> SubExp
tvSize TV Any
acc) []
        forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
globalFence
        forall {k} (rep :: k) r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
          forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
statusFlags [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 forall a. Num a => a
statusP) []
        forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanOpNe [TV Any]
accs) forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, TV Any
acc) ->
          forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
acc) [] SubExp
ne []
      -- end sWhen

      let warpSize :: TExp Int32
warpSize = KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants
      forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TPrimExp Bool VName
blockNewSgm forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
warpSize) forall a b. (a -> b) -> a -> b
$ do
        forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) forall a b. (a -> b) -> a -> b
$ do
          forall {k} (rep :: k) r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
            (TPrimExp Bool VName
not_segmented_e forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TExp Int32
boundary forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64
group_size' forall a. Num a => a -> a -> a
* forall a. Num a => a
m))
            ( do
                forall {k} (rep :: k) r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
                  forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
aggregateArrays [TV Any]
accs) forall a b. (a -> b) -> a -> b
$ \(VName
aggregateArray, TV Any
acc) ->
                    forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
aggregateArray [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dynamicId] (forall {k} (t :: k). TV t -> SubExp
tvSize TV Any
acc) []
                forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
globalFence
                forall {k} (rep :: k) r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
                  forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
statusFlags [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 forall a. Num a => a
statusA) []
            )
            ( do
                forall {k} (rep :: k) r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
                  forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays [TV Any]
accs) forall a b. (a -> b) -> a -> b
$ \(VName
incprefixArray, TV Any
acc) ->
                    forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dynamicId] (forall {k} (t :: k). TV t -> SubExp
tvSize TV Any
acc) []
                forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
globalFence
                forall {k} (rep :: k) r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
                  forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
statusFlags [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 forall a. Num a => a
statusP) []
            )
          forall {k} (rep :: k) r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
            forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
warpscan [TExp Int64
0] (VName -> SubExp
Var VName
statusFlags) [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dynamicId forall a. Num a => a -> a -> a
- TExp Int64
1]
        -- end sWhen (local thread 0)
        forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
localFence

        TV Int8
status <- forall {k1} {k2} (rep :: k1) r op (t :: k2).
String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"status" PrimType
int8 :: InKernelGen (TV Int8)
        forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Int8
status) [] (VName -> SubExp
Var VName
warpscan) [TExp Int64
0]

        forall {k} (rep :: k) r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
          (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int8
status forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall a. Num a => a
statusP)
          ( forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) forall a b. (a -> b) -> a -> b
$
              forall {k} (rep :: k) r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
                forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
prefixes [VName]
incprefixArrays) forall a b. (a -> b) -> a -> b
$ \(TV Any
prefix, VName
incprefixArray) ->
                  forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
prefix) [] (VName -> SubExp
Var VName
incprefixArray) [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dynamicId forall a. Num a => a -> a -> a
- TExp Int64
1]
          )
          ( do
              TV Int32
readOffset <-
                forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"readOffset" forall a b. (a -> b) -> a -> b
$
                  forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$
                    forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dynamicId forall a. Num a => a -> a -> a
- forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants)
              let loopStop :: TExp Int32
loopStop = TExp Int32
warpSize forall a. Num a => a -> a -> a
* (-TExp Int32
1)
                  sameSegment :: TV Int32 -> TPrimExp Bool VName
sameSegment TV Int32
readIdx
                    | Bool
segmented =
                        let startIdx :: TExp Int64
startIdx = forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readIdx forall a. Num a => a -> a -> a
+ TExp Int32
1) forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants forall a. Num a => a -> a -> a
* forall a. Num a => a
m forall a. Num a => a -> a -> a
- TExp Int64
1
                         in forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
blockOff forall a. Num a => a -> a -> a
- TExp Int64
startIdx forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
sgmIdx
                    | Bool
otherwise = forall v. TPrimExp Bool v
true
              forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhile (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readOffset forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int32
loopStop) forall a b. (a -> b) -> a -> b
$ do
                TV Int32
readI <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"read_i" forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readOffset forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
                [TV Any]
aggrs <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanOpNe [PrimType]
tys) forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, PrimType
ty) ->
                  forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"aggr" forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp forall a b. (a -> b) -> a -> b
$ forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
                TV Int8
flag <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"flag" (forall a. Num a => a
statusX :: Imp.TExp Int8)
                forall {k} (rep :: k) r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readI forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>=. TExp Int32
0) forall a b. (a -> b) -> a -> b
$ do
                  forall {k} (rep :: k) r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
                    (TV Int32 -> TPrimExp Bool VName
sameSegment TV Int32
readI)
                    ( do
                        forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Int8
flag) [] (VName -> SubExp
Var VName
statusFlags) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readI]
                        forall {k} (rep :: k) r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
                          (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int8
flag forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall a. Num a => a
statusP)
                          ( forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays [TV Any]
aggrs) forall a b. (a -> b) -> a -> b
$ \(VName
incprefix, TV Any
aggr) ->
                              forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
incprefix) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readI]
                          )
                          ( forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int8
flag forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall a. Num a => a
statusA) forall a b. (a -> b) -> a -> b
$ do
                              forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
aggrs [VName]
aggregateArrays) forall a b. (a -> b) -> a -> b
$ \(TV Any
aggr, VName
aggregate) ->
                                forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
aggregate) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readI]
                          )
                    )
                    (forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Int8
flag) [] (IntType -> Integer -> SubExp
intConst IntType
Int8 forall a. Num a => a
statusP) [])
                -- end sIf
                -- end sWhen

                forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
aggrs) forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
aggr) ->
                  forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
exchange [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants] (forall {k} (t :: k). TV t -> SubExp
tvSize TV Any
aggr) []
                forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
warpscan [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants] (forall {k} (t :: k). TV t -> SubExp
tvSize TV Int8
flag) []

                -- execute warp-parallel reduction but only if the last read flag is not STATUS_P
                forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Int8
flag) [] (VName -> SubExp
Var VName
warpscan) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
warpSize forall a. Num a => a -> a -> a
- TExp Int64
1]
                forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int8
flag forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. (TPrimExp Int8 VName
2 :: Imp.TExp Int8)) forall a b. (a -> b) -> a -> b
$ do
                  Lambda GPUMem
lam' <- forall {k} (rep :: k) (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scanOp'
                  KernelConstants
-> TExp Int64
-> VName
-> [VName]
-> Lambda GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
inBlockScanLookback
                    KernelConstants
constants
                    TExp Int64
num_threads
                    VName
warpscan
                    [VName]
exchanges
                    Lambda GPUMem
lam'

                -- all threads of the warp read the result of reduction
                forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Int8
flag) [] (VName -> SubExp
Var VName
warpscan) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
warpSize forall a. Num a => a -> a -> a
- TExp Int64
1]
                forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
aggrs [VName]
exchanges) forall a b. (a -> b) -> a -> b
$ \(TV Any
aggr, VName
exchange) ->
                  forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
exchange) [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
warpSize forall a. Num a => a -> a -> a
- TExp Int64
1]
                -- update read offset
                forall {k} (rep :: k) r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
                  (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int8
flag forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall a. Num a => a
statusP)
                  (TV Int32
readOffset forall {k1} {k2} (t :: k1) (rep :: k2) r op.
TV t -> TExp t -> ImpM rep r op ()
<-- TExp Int32
loopStop)
                  ( forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int8
flag forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall a. Num a => a
statusA) forall a b. (a -> b) -> a -> b
$ do
                      TV Int32
readOffset forall {k1} {k2} (t :: k1) (rep :: k2) r op.
TV t -> TExp t -> ImpM rep r op ()
<-- forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readOffset forall a. Num a => a -> a -> a
- forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
zExt32 TExp Int32
warpSize
                  )

                -- update prefix if flag different than STATUS_X:
                forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int8
flag forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (forall a. Num a => a
statusX :: Imp.TExp Int8)) forall a b. (a -> b) -> a -> b
$ do
                  Lambda GPUMem
lam <- forall {k} (rep :: k) (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scanOp'
                  let ([VName]
xs, [VName]
ys) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam
                  forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
aggrs) forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
aggr) -> forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x (forall {k} (t :: k). TV t -> TExp t
tvExp TV Any
aggr)
                  forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [TV Any]
prefixes) forall a b. (a -> b) -> a -> b
$ \(VName
y, TV Any
prefix) -> forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y (forall {k} (t :: k). TV t -> TExp t
tvExp TV Any
prefix)
                  forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPUMem
lam) forall a b. (a -> b) -> a -> b
$
                    forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TV Any]
prefixes [PrimType]
tys forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPUMem
lam) forall a b. (a -> b) -> a -> b
$
                      \(TV Any
prefix, PrimType
ty, SubExp
res) -> TV Any
prefix forall {k1} {k2} (t :: k1) (rep :: k2) r op.
TV t -> TExp t -> ImpM rep r op ()
<-- forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp (forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
res)
                forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
localFence
          )
        -- end sWhile
        -- end sIf
        forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) forall a b. (a -> b) -> a -> b
$ do
          Lambda GPUMem
scanOp'''' <- forall {k} (rep :: k) (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scanOp'
          let xs :: [VName]
xs = forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ forall a. Int -> [a] -> [a]
take (forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scanOp''''
              ys :: [VName]
ys = forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ forall a. Int -> [a] -> [a]
drop (forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scanOp''''
          forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
boundary forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64
group_size' forall a. Num a => a -> a -> a
* forall a. Num a => a
m)) forall a b. (a -> b) -> a -> b
$ do
            forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
prefixes) forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
prefix) -> forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Any
prefix
            forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [TV Any]
accs) forall a b. (a -> b) -> a -> b
$ \(VName
y, TV Any
acc) -> forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Any
acc
            forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scanOp'''') forall a b. (a -> b) -> a -> b
$
              forall {k} (rep :: k) r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$
                forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scanOp'''') forall a b. (a -> b) -> a -> b
$
                  \(VName
incprefixArray, SubExp
res) -> forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dynamicId] SubExp
res []
            forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
globalFence
            forall {k} (rep :: k) r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
statusFlags [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 forall a. Num a => a
statusP) []
          forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
prefixes) forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
prefix) ->
            forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
exchange [TExp Int64
0] (forall {k} (t :: k). TV t -> SubExp
tvSize TV Any
prefix) []
          forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TV Any]
accs [PrimType]
tys [SubExp]
scanOpNe) forall a b. (a -> b) -> a -> b
$ \(TV Any
acc, PrimType
ty, SubExp
ne) ->
            forall {k} (t :: k). TV t -> VName
tvVar TV Any
acc forall {k} (rep :: k) r op. VName -> Exp -> ImpM rep r op ()
<~~ forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
      -- end sWhen
      -- end sWhen

      forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dynamicId forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0) forall a b. (a -> b) -> a -> b
$ do
        forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
        forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
prefixes) forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
prefix) ->
          forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
prefix) [] (VName -> SubExp
Var VName
exchange) [TExp Int64
0]
        forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
    -- end sWhen
    -- end sComment

    Lambda GPUMem
scanOp''''' <- forall {k} (rep :: k) (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scanOp'
    Lambda GPUMem
scanOp'''''' <- forall {k} (rep :: k) (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scanOp'

    forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Distribute results" forall a b. (a -> b) -> a -> b
$ do
      let ([VName]
xs, [VName]
ys) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scanOp'''''
          ([VName]
xs', [VName]
ys') = forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scanOp''''''

      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 (forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
prefixes [TV Any]
accs) (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [VName]
xs') (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [VName]
ys') [PrimType]
tys) forall a b. (a -> b) -> a -> b
$
        \((TV Any
prefix, TV Any
acc), (VName
x, VName
x'), (VName
y, VName
y'), PrimType
ty) -> do
          forall {k} (rep :: k) r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
x PrimType
ty
          forall {k} (rep :: k) r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
y PrimType
ty
          forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x' forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Any
prefix
          forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y' forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Any
acc

      forall {k} (rep :: k) r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
        (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants forall a. Num a => a -> a -> a
* forall a. Num a => a
m forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
boundary forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TPrimExp Bool VName
blockNewSgm)
        ( forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scanOp'''''') forall a b. (a -> b) -> a -> b
$
            forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
xs [PrimType]
tys forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scanOp'''''') forall a b. (a -> b) -> a -> b
$
              \(VName
x, PrimType
ty, SubExp
res) -> VName
x forall {k} (rep :: k) r op. VName -> Exp -> ImpM rep r op ()
<~~ forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
res
        )
        (forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
accs) forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
acc) -> forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
x [] (VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> VName
tvVar TV Any
acc) [])
      -- calculate where previous thread stopped, to determine number of
      -- elements left before new segment.
      TExp Int32
stop <-
        forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"stopping_point" forall a b. (a -> b) -> a -> b
$
          TExp Int32
segsize_compact forall a. Num a => a -> a -> a
- (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants forall a. Num a => a -> a -> a
* forall a. Num a => a
m forall a. Num a => a -> a -> a
- TExp Int32
1 forall a. Num a => a -> a -> a
+ TExp Int32
segsize_compact forall a. Num a => a -> a -> a
- TExp Int32
boundary) forall a. IntegralExp a => a -> a -> a
`rem` TExp Int32
segsize_compact
      forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" forall a. Num a => a
m forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
        forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
i forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
stop forall a. Num a => a -> a -> a
- TExp Int32
1) forall a b. (a -> b) -> a -> b
$ do
          forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays [VName]
ys) forall a b. (a -> b) -> a -> b
$ \(VName
src, VName
y) ->
            -- only include prefix for the first segment part per thread
            forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
y [] (VName -> SubExp
Var VName
src) [TExp Int64
i]
          forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scanOp''''') forall a b. (a -> b) -> a -> b
$
            forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scanOp''''') forall a b. (a -> b) -> a -> b
$
              \(VName
dest, SubExp
res) ->
                forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
res []

    forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Transpose scan output and Write it to global memory in coalesced fashion" forall a b. (a -> b) -> a -> b
$ do
      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
transposedArrays [VName]
privateArrays forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall dec. PatElem dec -> VName
patElemName [PatElem (MemInfo SubExp NoUniqueness MemBind)]
all_pes) forall a b. (a -> b) -> a -> b
$ \(VName
locmem, VName
priv, VName
dest) -> do
        -- sOp localBarrier
        forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" forall a. Num a => a
m forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
          TV Int64
sharedIdx <-
            forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"sharedIdx" forall a b. (a -> b) -> a -> b
$
              forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants forall a. Num a => a -> a -> a
* forall a. Num a => a
m) forall a. Num a => a -> a -> a
+ TExp Int64
i
          forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
locmem [forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
sharedIdx] (VName -> SubExp
Var VName
priv) [TExp Int64
i]
        forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier
        forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" forall a. Num a => a
m forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
          TExp Int64
flat_idx <-
            forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"flat_idx" forall a b. (a -> b) -> a -> b
$
              forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
blockOff
                forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants forall a. Num a => a -> a -> a
* TExp Int64
i
                forall a. Num a => a -> a -> a
+ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants)
          forall {k} (rep :: k) r op.
[(VName, TExp Int64)] -> TExp Int64 -> ImpM rep r op ()
dIndexSpace (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids [TExp Int64]
dims') TExp Int64
flat_idx
          forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int64
flat_idx forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
n) forall a b. (a -> b) -> a -> b
$ do
            forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
              VName
dest
              (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids)
              (VName -> SubExp
Var VName
locmem)
              [forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ TExp Int64
flat_idx forall a. Num a => a -> a -> a
- forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
blockOff]
        forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp KernelOp
localBarrier

    forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"If this is the last block, reset the dynamicId" forall a b. (a -> b) -> a -> b
$
      forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dynamicId forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
num_groups' forall a. Num a => a -> a -> a
- TExp Int64
1) forall a b. (a -> b) -> a -> b
$
        forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
globalId [TExp Int64
0] (forall v. IsValue v => v -> SubExp
constant (Int32
0 :: Int32)) []