{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Code generation for segmented and non-segmented scans.  Uses a
-- fast single-pass algorithm, which only works on NVIDIA GPUs and
-- only for operators satisfying some constraints.  We use it whenever
-- we can.
module Futhark.CodeGen.ImpGen.Kernels.SegScan.SinglePass (compileSegScan) where

import Control.Monad.Except
import Data.List (zip4)
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.Base
import Futhark.IR.KernelsMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Transform.Rename
import Futhark.Util (takeLast)
import Futhark.Util.IntegralExp (IntegralExp, divUp, quot)
import Prelude hiding (quot)

-- | Split the lambda parameters of a scan operator into its two operand
-- groups.  'xParams' are the first @length (segBinOpNeutral scan)@
-- parameters; 'yParams' are the remaining ones.
xParams, yParams :: SegBinOp KernelsMem -> [LParam KernelsMem]
xParams = fst . splitScanParams
yParams = snd . splitScanParams

-- | Split the operator's lambda parameters at the number of neutral
-- elements.  This is equivalent to pairing 'take' and 'drop' with the
-- same count.
splitScanParams ::
  SegBinOp KernelsMem ->
  ([LParam KernelsMem], [LParam KernelsMem])
splitScanParams scan =
  splitAt (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))

-- | @alignTo x a@ rounds @x@ up to a multiple of @a@, using 'divUp'
-- followed by multiplication.
--
-- In this module it places each local array at a byte offset that is a
-- multiple of its element size.
alignTo :: IntegralExp a => a -> a -> a
alignTo x a = (x `divUp` a) * a

-- | Allocate the local memory used by the single-pass scan kernel and
-- carve it into typed array views.
--
-- Everything lives in one allocation in the @local@ space.  The views
-- overlap, because the memory is deliberately reused.  The allocation
-- size is therefore the maximum of three layouts:
--
--   * the look-back area: a @warpSize@-element @int8@ array
--     (@warpscan@), followed from byte @warpSize@ onwards by one
--     @warpSize@-element exchange array per type;
--   * one prefix array of @groupSize@ elements per type, packed from
--     byte 0;
--   * one transposition array of @m * groupSize@ elements per type.
--
-- The one-element @shared_id@ and @shared_read_offset@ arrays, the
-- transposition arrays and @warpscan@ are created with 'sArrayInMem'.
-- NOTE(review): presumably these all start at the beginning of the
-- allocation — confirm against 'sArrayInMem'.
--
-- Every packed array starts at an offset aligned (via 'alignTo') to its
-- element size.
--
-- Returns, in order:
--
--   1. the shared id array;
--   2. the transposition arrays;
--   3. the prefix arrays;
--   4. the shared read-offset array;
--   5. the warp scan array;
--   6. the warp exchange arrays.
createLocalArrays ::
  Count GroupSize SubExp ->
  SubExp ->
  [PrimType] ->
  InKernelGen (VName, [VName], [VName], VName, VName, [VName])
createLocalArrays (Count groupSize) m types = do
  let groupSizeE, workSize :: TExp Int64
      groupSizeE = toInt64Exp groupSize
      -- Length of each transposition array.
      workSize = toInt64Exp m * groupSizeE

      -- Element size in bytes of each type.
      tySizes :: [TExp Int64]
      tySizes = map primByteSize types

      warpSize :: Num a => a
      warpSize = 32

      -- Bytes needed to pack one @count@-element array per type back to
      -- back, aligning each array to its element size.
      packedSize :: TExp Int64 -> TExp Int64
      packedSize count =
        foldl (\acc tySize -> alignTo acc tySize + tySize * count) 0 tySizes

      prefixArraysSize = packedSize groupSizeE
      -- The warpSize-byte int8 'warpscan' array precedes the exchanges.
      maxLookbackSize = packedSize warpSize + warpSize
      maxTransposedArraySize = foldl1 sMax64 $ map (workSize *) tySizes

      size :: Count Bytes (TExp Int64)
      size =
        Imp.bytes $
          maxLookbackSize `sMax64` prefixArraysSize `sMax64` maxTransposedArraySize

      -- Declare kernel variables holding the byte offsets of one
      -- @count@-element array per type, packed from byte @start@.  The
      -- result has one more entry than 'types': the last entry is the
      -- end of the layout.  It is dropped by the 'zip's below.
      declareOffsets ::
        String -> TExp Int64 -> TExp Int64 -> InKernelGen [TPrimExp Int64 VName]
      declareOffsets desc start count =
        mapM (fmap (le64 . tvVar) . dPrimV desc) $
          scanl (\off tySize -> alignTo off tySize + count * tySize) start tySizes

  byteOffsets <- declareOffsets "byte_offsets" 0 groupSizeE
  warpByteOffsets <- declareOffsets "warp_byte_offset" warpSize warpSize

  sComment "Allocate reused shared memory" $ return ()

  localMem <- sAlloc "local_mem" size (Space "local")
  transposeArrayLength <- dPrimV "trans_arr_len" workSize

  sharedId <- sArrayInMem "shared_id" int32 (Shape [constant (1 :: Int32)]) localMem
  sharedReadOffset <- sArrayInMem "shared_read_offset" int32 (Shape [constant (1 :: Int32)]) localMem

  transposedArrays <-
    forM types $ \ty ->
      sArrayInMem "local_transpose_arr" ty (Shape [tvSize transposeArrayLength]) localMem

  let -- A @count@-element view of @ty@ in 'localMem' at byte offset
      -- @off@.  The offset is converted to element units, which is exact
      -- because 'alignTo' made it a multiple of the element size.
      arrayAt :: String -> SubExp -> (TPrimExp Int64 VName, PrimType) -> InKernelGen VName
      arrayAt desc count (off, ty) =
        sArray desc ty (Shape [count]) $
          ArrayIn localMem $
            IxFun.iotaOffset (off `quot` primByteSize ty) [pe64 count]

  prefixArrays <- mapM (arrayAt "local_prefix_arr" groupSize) (zip byteOffsets types)

  warpscan <- sArrayInMem "warpscan" int8 (Shape [constant (warpSize :: Int64)]) localMem
  warpExchanges <-
    mapM (arrayAt "warp_exchange" (constant (warpSize :: Int64))) (zip warpByteOffsets types)

  return (sharedId, transposedArrays, prefixArrays, sharedReadOffset, warpscan, warpExchanges)

-- | Compile 'SegScan' instance to host-level code with calls to a
-- single-pass kernel.
compileSegScan ::
  Pattern KernelsMem ->
  SegLevel ->
  SegSpace ->
  SegBinOp KernelsMem ->
  KernelBody KernelsMem ->
  CallKernelGen ()
compileSegScan :: Pattern KernelsMem
-> SegLevel
-> SegSpace
-> SegBinOp KernelsMem
-> KernelBody KernelsMem
-> CallKernelGen ()
compileSegScan Pattern KernelsMem
pat SegLevel
lvl SegSpace
space SegBinOp KernelsMem
scanOp KernelBody KernelsMem
kbody = do
  let Pattern [PatElemT LParamMem]
_ [PatElemT LParamMem]
all_pes = Pattern KernelsMem
PatternT LParamMem
pat
      group_size :: Count GroupSize (TExp Int64)
group_size = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> Count GroupSize SubExp -> Count GroupSize (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl
      n :: TExp Int64
n = [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space
      m :: Num a => a
      m :: a
m = a
9
      num_groups :: Count NumGroups (TExp Int64)
num_groups = TExp Int64 -> Count NumGroups (TExp Int64)
forall u e. e -> Count u e
Count (TExp Int64
n TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` (Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
forall a. Num a => a
m))
      num_threads :: TExp Int64
num_threads = Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size
      (VName
mapIdx, SubExp
_) = [(VName, SubExp)] -> (VName, SubExp)
forall a. [a] -> a
head ([(VName, SubExp)] -> (VName, SubExp))
-> [(VName, SubExp)] -> (VName, SubExp)
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      scanOpNe :: [SubExp]
scanOpNe = SegBinOp KernelsMem -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral SegBinOp KernelsMem
scanOp
      tys :: [PrimType]
tys = (TypeBase (ShapeBase SubExp) NoUniqueness -> PrimType)
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> [PrimType]
forall a b. (a -> b) -> [a] -> [b]
map (\(Prim PrimType
pt) -> PrimType
pt) ([TypeBase (ShapeBase SubExp) NoUniqueness] -> [PrimType])
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> [PrimType]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType (LambdaT KernelsMem -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> LambdaT KernelsMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> a -> b
$ SegBinOp KernelsMem -> LambdaT KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda SegBinOp KernelsMem
scanOp
      statusX, statusA, statusP :: Num a => a
      statusX :: a
statusX = a
0
      statusA :: a
statusA = a
1
      statusP :: a
statusP = a
2
      makeStatusUsed :: TV t -> TV t -> TPrimExp t ExpLeaf
makeStatusUsed TV t
flag TV t
used = TV t -> TPrimExp t ExpLeaf
forall t. TV t -> TExp t
tvExp TV t
flag TPrimExp t ExpLeaf -> TPrimExp t ExpLeaf -> TPrimExp t ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.|. (TV t -> TPrimExp t ExpLeaf
forall t. TV t -> TExp t
tvExp TV t
used TPrimExp t ExpLeaf -> TPrimExp t ExpLeaf -> TPrimExp t ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.<<. TPrimExp t ExpLeaf
2)
      unmakeStatusUsed :: TV Int8 -> TV Int8 -> TV Int8 -> InKernelGen ()
      unmakeStatusUsed :: TV Int8
-> TV Int8 -> TV Int8 -> ImpM KernelsMem KernelEnv KernelOp ()
unmakeStatusUsed TV Int8
flagUsed TV Int8
flag TV Int8
used = do
        TV Int8
used TV Int8 -> TExp Int8 -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flagUsed TExp Int8 -> TExp Int8 -> TExp Int8
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.>>. TExp Int8
2
        TV Int8
flag TV Int8 -> TExp Int8 -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flagUsed TExp Int8 -> TExp Int8 -> TExp Int8
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.&. TExp Int8
3

  -- Allocate the shared memory for output component
  TV Int64
numThreads <- String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"numThreads" TExp Int64
num_threads
  TV Int64
numGroups <- String -> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"numGroups" (TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM KernelsMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups

  VName
globalId <- String
-> Space
-> PrimType
-> ArrayContents
-> ImpM KernelsMem HostEnv HostOp VName
forall lore r op.
String
-> Space -> PrimType -> ArrayContents -> ImpM lore r op VName
sStaticArray String
"id_counter" (String -> Space
Space String
"device") PrimType
int32 (ArrayContents -> ImpM KernelsMem HostEnv HostOp VName)
-> ArrayContents -> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ Int -> ArrayContents
Imp.ArrayZeros Int
1
  VName
statusFlags <- String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM KernelsMem HostEnv HostOp VName
forall lore r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM lore r op VName
sAllocArray String
"status_flags" PrimType
int8 ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
numGroups]) (String -> Space
Space String
"device")
  ([VName]
aggregateArrays, [VName]
incprefixArrays) <-
    ([(VName, VName)] -> ([VName], [VName]))
-> ImpM KernelsMem HostEnv HostOp [(VName, VName)]
-> ImpM KernelsMem HostEnv HostOp ([VName], [VName])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [(VName, VName)] -> ([VName], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip (ImpM KernelsMem HostEnv HostOp [(VName, VName)]
 -> ImpM KernelsMem HostEnv HostOp ([VName], [VName]))
-> ImpM KernelsMem HostEnv HostOp [(VName, VName)]
-> ImpM KernelsMem HostEnv HostOp ([VName], [VName])
forall a b. (a -> b) -> a -> b
$
      [PrimType]
-> (PrimType -> ImpM KernelsMem HostEnv HostOp (VName, VName))
-> ImpM KernelsMem HostEnv HostOp [(VName, VName)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [PrimType]
tys ((PrimType -> ImpM KernelsMem HostEnv HostOp (VName, VName))
 -> ImpM KernelsMem HostEnv HostOp [(VName, VName)])
-> (PrimType -> ImpM KernelsMem HostEnv HostOp (VName, VName))
-> ImpM KernelsMem HostEnv HostOp [(VName, VName)]
forall a b. (a -> b) -> a -> b
$ \PrimType
ty ->
        (,) (VName -> VName -> (VName, VName))
-> ImpM KernelsMem HostEnv HostOp VName
-> ImpM KernelsMem HostEnv HostOp (VName -> (VName, VName))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM KernelsMem HostEnv HostOp VName
forall lore r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM lore r op VName
sAllocArray String
"aggregates" PrimType
ty ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
numGroups]) (String -> Space
Space String
"device")
          ImpM KernelsMem HostEnv HostOp (VName -> (VName, VName))
-> ImpM KernelsMem HostEnv HostOp VName
-> ImpM KernelsMem HostEnv HostOp (VName, VName)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM KernelsMem HostEnv HostOp VName
forall lore r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM lore r op VName
sAllocArray String
"incprefixes" PrimType
ty ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
numGroups]) (String -> Space
Space String
"device")

  VName -> SubExp -> CallKernelGen ()
sReplicate VName
statusFlags (SubExp -> CallKernelGen ()) -> SubExp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusX

  String
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> VName
-> ImpM KernelsMem KernelEnv KernelOp ()
-> CallKernelGen ()
sKernelThread String
"segscan" Count NumGroups (TExp Int64)
num_groups Count GroupSize (TExp Int64)
group_size (SegSpace -> VName
segFlat SegSpace
space) (ImpM KernelsMem KernelEnv KernelOp () -> CallKernelGen ())
-> ImpM KernelsMem KernelEnv KernelOp () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv

    (VName
sharedId, [VName]
transposedArrays, [VName]
prefixArrays, VName
sharedReadOffset, VName
warpscan, [VName]
exchanges) <-
      Count GroupSize SubExp
-> SubExp
-> [PrimType]
-> InKernelGen (VName, [VName], [VName], VName, VName, [VName])
createLocalArrays (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
forall a. Num a => a
m) [PrimType]
tys

    TV Int64
dynamicId <- String -> PrimType -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"dynamic_id" PrimType
int32
    TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
      (VName
globalIdMem, Space
_, Count Elements (TExp Int64)
globalIdOff) <- VName
-> [TExp Int64]
-> ImpM
     KernelsMem
     KernelEnv
     KernelOp
     (VName, Space, Count Elements (TExp Int64))
forall lore r op.
VName
-> [TExp Int64]
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray VName
globalId [TExp Int64
0]
      KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> ImpM KernelsMem KernelEnv KernelOp ())
-> KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
        Space -> AtomicOp -> KernelOp
Imp.Atomic Space
DefaultSpace (AtomicOp -> KernelOp) -> AtomicOp -> KernelOp
forall a b. (a -> b) -> a -> b
$
          IntType
-> VName -> VName -> Count Elements (TExp Int64) -> Exp -> AtomicOp
Imp.AtomicAdd
            IntType
Int32
            (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
dynamicId)
            VName
globalIdMem
            (TExp Int64 -> Count Elements (TExp Int64)
forall u e. e -> Count u e
Count (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count Elements (TExp Int64)
globalIdOff)
            (TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32
1 :: Imp.TExp Int32))
      VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
sharedId [TExp Int64
0] (TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int64
dynamicId) []

    let localBarrier :: KernelOp
localBarrier = Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
        localFence :: KernelOp
localFence = Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceLocal
        globalFence :: KernelOp
globalFence = Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceGlobal

    KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
localBarrier
    VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
dynamicId) [] (VName -> SubExp
Var VName
sharedId) [TExp Int64
0]
    KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
localBarrier

    TV Int64
blockOff <-
      String
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"blockOff" (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
        TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
forall a. Num a => a
m TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants

    [VName]
privateArrays <-
      [PrimType]
-> (PrimType -> ImpM KernelsMem KernelEnv KernelOp VName)
-> ImpM KernelsMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [PrimType]
tys ((PrimType -> ImpM KernelsMem KernelEnv KernelOp VName)
 -> ImpM KernelsMem KernelEnv KernelOp [VName])
-> (PrimType -> ImpM KernelsMem KernelEnv KernelOp VName)
-> ImpM KernelsMem KernelEnv KernelOp [VName]
forall a b. (a -> b) -> a -> b
$ \PrimType
ty ->
        String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM lore r op VName
sAllocArray
          String
"private"
          PrimType
ty
          ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
forall a. Num a => a
m])
          ([SubExp] -> PrimType -> Space
ScalarSpace [IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
forall a. Num a => a
m] PrimType
ty)

    String
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Load and map" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
      String
-> TExp Int64
-> (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" TExp Int64
forall a. Num a => a
m ((TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
        -- The map's input index
        VName -> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ VName
mapIdx (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
          TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
blockOff TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants)
            TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants
        -- Perform the map
        let in_bounds :: ImpM KernelsMem KernelEnv KernelOp ()
in_bounds =
              Names
-> Stms KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody KernelsMem -> Stms KernelsMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody KernelsMem
kbody) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
                let ([KernelResult]
all_scan_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp KernelsMem] -> Int
forall lore. [SegBinOp lore] -> Int
segBinOpResults [SegBinOp KernelsMem
scanOp]) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody KernelsMem -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody KernelsMem
kbody

                -- Write map results to their global memory destinations
                [(PatElemT LParamMem, KernelResult)]
-> ((PatElemT LParamMem, KernelResult)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LParamMem]
-> [KernelResult] -> [(PatElemT LParamMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int -> [PatElemT LParamMem] -> [PatElemT LParamMem]
forall a. Int -> [a] -> [a]
takeLast ([KernelResult] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
map_res) [PatElemT LParamMem]
all_pes) [KernelResult]
map_res) (((PatElemT LParamMem, KernelResult)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((PatElemT LParamMem, KernelResult)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LParamMem
dest, KernelResult
src) ->
                  VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
dest) [VName -> TExp Int64
Imp.vi64 VName
mapIdx] (KernelResult -> SubExp
kernelResultSubExp KernelResult
src) []

                -- Write to-scan results to private memory.
                [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ (KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp [KernelResult]
all_scan_res) (((VName, SubExp) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
src) ->
                  VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
src []

            out_of_bounds :: ImpM KernelsMem KernelEnv KernelOp ()
out_of_bounds =
              [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays [SubExp]
scanOpNe) (((VName, SubExp) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
ne) ->
                VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
ne []

        TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (VName -> TExp Int64
Imp.vi64 VName
mapIdx TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
n) ImpM KernelsMem KernelEnv KernelOp ()
in_bounds ImpM KernelsMem KernelEnv KernelOp ()
out_of_bounds

    String
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Transpose scan inputs" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
      [(VName, VName)]
-> ((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
transposedArrays [VName]
privateArrays) (((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
trans, VName
priv) -> do
        KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
localBarrier
        String
-> TExp Int64
-> (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" TExp Int64
forall a. Num a => a
m ((TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
          TExp Int64
sharedIdx <-
            String
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"sharedIdx" (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
              TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants)
                TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants
          VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
trans [TExp Int64
sharedIdx] (VName -> SubExp
Var VName
priv) [TExp Int64
i]
        KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
localBarrier
        String
-> TExp Int32
-> (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" TExp Int32
forall a. Num a => a
m ((TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
          TV Int32
sharedIdx <- String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TV Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"sharedIdx" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TV Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
forall a. Num a => a
m TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
i
          VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
priv [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
i] (VName -> SubExp
Var VName
trans) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
sharedIdx]
      KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
localBarrier

    String
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Per thread scan" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
      -- We don't need to touch the first element, so only m-1
      -- iterations here.
      String
-> TExp Int64
-> (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" (TExp Int64
forall a. Num a => a
m TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
-TExp Int64
1) ((TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
        let xs :: [VName]
xs = (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegBinOp KernelsMem -> [LParam KernelsMem]
xParams SegBinOp KernelsMem
scanOp
            ys :: [VName]
ys = (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegBinOp KernelsMem -> [LParam KernelsMem]
yParams SegBinOp KernelsMem
scanOp

        [(VName, (VName, VName, PrimType))]
-> ((VName, (VName, VName, PrimType))
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName]
-> [(VName, VName, PrimType)]
-> [(VName, (VName, VName, PrimType))]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays ([(VName, VName, PrimType)] -> [(VName, (VName, VName, PrimType))])
-> [(VName, VName, PrimType)]
-> [(VName, (VName, VName, PrimType))]
forall a b. (a -> b) -> a -> b
$ [VName] -> [VName] -> [PrimType] -> [(VName, VName, PrimType)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
xs [VName]
ys [PrimType]
tys) (((VName, (VName, VName, PrimType))
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, (VName, VName, PrimType))
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
src, (VName
x, VName
y, PrimType
ty)) -> do
          VName -> PrimType -> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. VName -> PrimType -> ImpM lore r op ()
dPrim_ VName
x PrimType
ty
          VName -> PrimType -> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. VName -> PrimType -> ImpM lore r op ()
dPrim_ VName
y PrimType
ty
          VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
x [] (VName -> SubExp
Var VName
src) [TExp Int64
i]
          VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
y [] (VName -> SubExp
Var VName
src) [TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1]

        Names
-> Stms KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT KernelsMem -> Stms KernelsMem
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT KernelsMem -> Stms KernelsMem)
-> BodyT KernelsMem -> Stms KernelsMem
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody (LambdaT KernelsMem -> BodyT KernelsMem)
-> LambdaT KernelsMem -> BodyT KernelsMem
forall a b. (a -> b) -> a -> b
$ SegBinOp KernelsMem -> LambdaT KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda SegBinOp KernelsMem
scanOp) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
          [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ BodyT KernelsMem -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (BodyT KernelsMem -> [SubExp]) -> BodyT KernelsMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody (LambdaT KernelsMem -> BodyT KernelsMem)
-> LambdaT KernelsMem -> BodyT KernelsMem
forall a b. (a -> b) -> a -> b
$ SegBinOp KernelsMem -> LambdaT KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda SegBinOp KernelsMem
scanOp) (((VName, SubExp) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
res) ->
            VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
dest [TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1] SubExp
res []

    String
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Publish results in shared memory" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
      [(VName, VName)]
-> ((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
prefixArrays [VName]
privateArrays) (((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, VName
src) ->
        VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
dest [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants] (VName -> SubExp
Var VName
src) [TExp Int64
forall a. Num a => a
m TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
      KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
localBarrier

    LambdaT KernelsMem
scanOp' <- LambdaT KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp (LambdaT KernelsMem)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda (LambdaT KernelsMem
 -> ImpM KernelsMem KernelEnv KernelOp (LambdaT KernelsMem))
-> LambdaT KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp (LambdaT KernelsMem)
forall a b. (a -> b) -> a -> b
$ SegBinOp KernelsMem -> LambdaT KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda SegBinOp KernelsMem
scanOp

    [TV Any]
accs <- (PrimType -> ImpM KernelsMem KernelEnv KernelOp (TV Any))
-> [PrimType] -> ImpM KernelsMem KernelEnv KernelOp [TV Any]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> PrimType -> ImpM KernelsMem KernelEnv KernelOp (TV Any)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"acc") [PrimType]
tys
    String
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Scan results (with warp scan)" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
      Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Int64
-> TExp Int64
-> LambdaT KernelsMem
-> [VName]
-> ImpM KernelsMem KernelEnv KernelOp ()
groupScan
        Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
forall a. Maybe a
Nothing -- TODO
        (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
numThreads)
        (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
        LambdaT KernelsMem
scanOp'
        [VName]
prefixArrays

      KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
localBarrier

      let firstThread :: TV Any -> VName -> ImpM KernelsMem KernelEnv KernelOp ()
firstThread TV Any
acc VName
prefixes =
            VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc) [] (VName -> SubExp
Var VName
prefixes) [TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
          notFirstThread :: TV Any -> VName -> ImpM KernelsMem KernelEnv KernelOp ()
notFirstThread TV Any
acc VName
prefixes =
            VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc) [] (VName -> SubExp
Var VName
prefixes) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
      TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf
        (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0)
        ((TV Any -> VName -> ImpM KernelsMem KernelEnv KernelOp ())
-> [TV Any] -> [VName] -> ImpM KernelsMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ TV Any -> VName -> ImpM KernelsMem KernelEnv KernelOp ()
firstThread [TV Any]
accs [VName]
prefixArrays)
        ((TV Any -> VName -> ImpM KernelsMem KernelEnv KernelOp ())
-> [TV Any] -> [VName] -> ImpM KernelsMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ TV Any -> VName -> ImpM KernelsMem KernelEnv KernelOp ()
notFirstThread [TV Any]
accs [VName]
prefixArrays)

      KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
localBarrier

    [TV Any]
prefixes <- [(SubExp, PrimType)]
-> ((SubExp, PrimType)
    -> ImpM KernelsMem KernelEnv KernelOp (TV Any))
-> ImpM KernelsMem KernelEnv KernelOp [TV Any]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SubExp] -> [PrimType] -> [(SubExp, PrimType)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanOpNe [PrimType]
tys) (((SubExp, PrimType)
  -> ImpM KernelsMem KernelEnv KernelOp (TV Any))
 -> ImpM KernelsMem KernelEnv KernelOp [TV Any])
-> ((SubExp, PrimType)
    -> ImpM KernelsMem KernelEnv KernelOp (TV Any))
-> ImpM KernelsMem KernelEnv KernelOp [TV Any]
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, PrimType
ty) ->
      String -> TExp Any -> ImpM KernelsMem KernelEnv KernelOp (TV Any)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"prefix" (TExp Any -> ImpM KernelsMem KernelEnv KernelOp (TV Any))
-> TExp Any -> ImpM KernelsMem KernelEnv KernelOp (TV Any)
forall a b. (a -> b) -> a -> b
$ Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Any) -> Exp -> TExp Any
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
    String
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Perform lookback" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
      TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0 TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
        ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op a. ImpM lore r op a -> ImpM lore r op a
everythingVolatile (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
          [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays [TV Any]
accs) (((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
incprefixArray, TV Any
acc) ->
            VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
acc) []
        KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
globalFence
        ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op a. ImpM lore r op a -> ImpM lore r op a
everythingVolatile (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
          VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
statusFlags [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusP) []
        [(SubExp, TV Any)]
-> ((SubExp, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SubExp] -> [TV Any] -> [(SubExp, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanOpNe [TV Any]
accs) (((SubExp, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((SubExp, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, TV Any
acc) ->
          VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc) [] SubExp
ne []
      -- end sWhen

      let warpSize :: TExp Int32
warpSize = KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants
      TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0) TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
warpSize) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
        TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
          ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op a. ImpM lore r op a -> ImpM lore r op a
everythingVolatile (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
            [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
aggregateArrays [TV Any]
accs) (((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
aggregateArray, TV Any
acc) ->
              VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
aggregateArray [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
acc) []
          KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
globalFence
          ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op a. ImpM lore r op a -> ImpM lore r op a
everythingVolatile (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
            VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
statusFlags [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusA) []
          VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
warpscan [TExp Int64
0] (VName -> SubExp
Var VName
statusFlags) [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
        -- sWhen
        KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
localFence

        TV Int8
status <- String -> PrimType -> ImpM KernelsMem KernelEnv KernelOp (TV Int8)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"status" PrimType
int8 :: InKernelGen (TV Int8)
        VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
status) [] (VName -> SubExp
Var VName
warpscan) [TExp Int64
0]

        TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf
          (TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
status TExp Int8 -> TExp Int8 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int8
forall a. Num a => a
statusP)
          ( TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
              ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op a. ImpM lore r op a -> ImpM lore r op a
everythingVolatile (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                [(TV Any, VName)]
-> ((TV Any, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [VName] -> [(TV Any, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
prefixes [VName]
incprefixArrays) (((TV Any, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((TV Any, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
prefix, VName
incprefixArray) ->
                  VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
prefix) [] (VName -> SubExp
Var VName
incprefixArray) [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
          )
          ( do
              TV Int32
readOffset <-
                String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TV Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"readOffset" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TV Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TV Int32)
forall a b. (a -> b) -> a -> b
$
                  TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants)
              let loopStop :: TExp Int32
loopStop = TExp Int32
warpSize TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* (-TExp Int32
1)
              TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhile (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readOffset TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int32
loopStop) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
                TV Int32
readI <- String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TV Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"read_i" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TV Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readOffset TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
                [TV Any]
aggrs <- [(SubExp, PrimType)]
-> ((SubExp, PrimType)
    -> ImpM KernelsMem KernelEnv KernelOp (TV Any))
-> ImpM KernelsMem KernelEnv KernelOp [TV Any]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SubExp] -> [PrimType] -> [(SubExp, PrimType)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
scanOpNe [PrimType]
tys) (((SubExp, PrimType)
  -> ImpM KernelsMem KernelEnv KernelOp (TV Any))
 -> ImpM KernelsMem KernelEnv KernelOp [TV Any])
-> ((SubExp, PrimType)
    -> ImpM KernelsMem KernelEnv KernelOp (TV Any))
-> ImpM KernelsMem KernelEnv KernelOp [TV Any]
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, PrimType
ty) ->
                  String -> TExp Any -> ImpM KernelsMem KernelEnv KernelOp (TV Any)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"aggr" (TExp Any -> ImpM KernelsMem KernelEnv KernelOp (TV Any))
-> TExp Any -> ImpM KernelsMem KernelEnv KernelOp (TV Any)
forall a b. (a -> b) -> a -> b
$ Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Any) -> Exp -> TExp Any
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
                TV Int8
flag <- String -> TExp Int8 -> ImpM KernelsMem KernelEnv KernelOp (TV Int8)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"flag" TExp Int8
forall a. Num a => a
statusX
                TV Int8
used <- String -> TExp Int8 -> ImpM KernelsMem KernelEnv KernelOp (TV Int8)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"used" (TExp Int8
0 :: Imp.TExp Int8)
                ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op a. ImpM lore r op a -> ImpM lore r op a
everythingVolatile (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                  TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readI TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>=. TExp Int32
0) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
                    VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
flag) [] (VName -> SubExp
Var VName
statusFlags) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readI]
                    TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf
                      (TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flag TExp Int8 -> TExp Int8 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int8
forall a. Num a => a
statusP)
                      ( [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays [TV Any]
aggrs) (((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
incprefix, TV Any
aggr) ->
                          VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
incprefix) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readI]
                      )
                      ( TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flag TExp Int8 -> TExp Int8 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int8
forall a. Num a => a
statusA) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
                          [(TV Any, VName)]
-> ((TV Any, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [VName] -> [(TV Any, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
aggrs [VName]
aggregateArrays) (((TV Any, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((TV Any, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
aggr, VName
aggregate) ->
                            VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
aggregate) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readI]
                          TV Int8
used TV Int8 -> TExp Int8 -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- (TExp Int8
1 :: Imp.TExp Int8)
                      )
                -- end sIf
                -- end sWhen
                [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
aggrs) (((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
aggr) ->
                  VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
exchange [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
aggr) []
                TV Int8
tmp <- String -> TExp Int8 -> ImpM KernelsMem KernelEnv KernelOp (TV Int8)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"tmp" (TExp Int8 -> ImpM KernelsMem KernelEnv KernelOp (TV Int8))
-> TExp Int8 -> ImpM KernelsMem KernelEnv KernelOp (TV Int8)
forall a b. (a -> b) -> a -> b
$ TV Int8 -> TV Int8 -> TExp Int8
forall t. NumExp t => TV t -> TV t -> TPrimExp t ExpLeaf
makeStatusUsed TV Int8
flag TV Int8
used
                VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
warpscan [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants] (TV Int8 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int8
tmp) []
                KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
localFence

                (VName
warpscanMem, Space
warpscanSpace, Count Elements (TExp Int64)
warpscanOff) <-
                  VName
-> [TExp Int64]
-> ImpM
     KernelsMem
     KernelEnv
     KernelOp
     (VName, Space, Count Elements (TExp Int64))
forall lore r op.
VName
-> [TExp Int64]
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray VName
warpscan [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
warpSize TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
                TV Int8
flag TV Int8 -> TExp Int8 -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- Exp -> TExp Int8
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
Imp.index VName
warpscanMem Count Elements (TExp Int64)
warpscanOff PrimType
int8 Space
warpscanSpace Volatility
Imp.Volatile)
                TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
                  -- TODO: This is a single-threaded reduce
                  TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf
                    (TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TExp Bool -> TExp Bool) -> TExp Bool -> TExp Bool
forall a b. (a -> b) -> a -> b
$ TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flag TExp Int8 -> TExp Int8 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int8
forall a. Num a => a
statusP)
                    ( do
                        LambdaT KernelsMem
scanOp'' <- LambdaT KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp (LambdaT KernelsMem)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda LambdaT KernelsMem
scanOp'
                        let ([VName]
agg1s, [VName]
agg2s) = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([VName] -> ([VName], [VName])) -> [VName] -> ([VName], [VName])
forall a b. (a -> b) -> a -> b
$ (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT KernelsMem
scanOp''

                        [(VName, SubExp, PrimType)]
-> ((VName, SubExp, PrimType)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [PrimType] -> [(VName, SubExp, PrimType)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
agg1s [SubExp]
scanOpNe [PrimType]
tys) (((VName, SubExp, PrimType)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, SubExp, PrimType)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
agg1, SubExp
ne, PrimType
ty) ->
                          VName -> TExp Any -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ VName
agg1 (TExp Any -> ImpM KernelsMem KernelEnv KernelOp ())
-> TExp Any -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Any) -> Exp -> TExp Any
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
                        (VName -> PrimType -> ImpM KernelsMem KernelEnv KernelOp ())
-> [VName] -> [PrimType] -> ImpM KernelsMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> PrimType -> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. VName -> PrimType -> ImpM lore r op ()
dPrim_ [VName]
agg2s [PrimType]
tys

                        TV Int8
flag1 <- String -> TExp Int8 -> ImpM KernelsMem KernelEnv KernelOp (TV Int8)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"flag1" TExp Int8
forall a. Num a => a
statusX
                        TV Int8
flag2 <- String -> PrimType -> ImpM KernelsMem KernelEnv KernelOp (TV Int8)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"flag2" PrimType
int8
                        TV Int8
used1 <- String -> TExp Int8 -> ImpM KernelsMem KernelEnv KernelOp (TV Int8)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"used1" (TExp Int8
0 :: Imp.TExp Int8)
                        TV Int8
used2 <- String -> PrimType -> ImpM KernelsMem KernelEnv KernelOp (TV Int8)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"used2" PrimType
int8
                        String
-> TExp Int32
-> (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" TExp Int32
warpSize ((TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
                          VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (TV Int8 -> VName
forall t. TV t -> VName
tvVar TV Int8
flag2) [] (VName -> SubExp
Var VName
warpscan) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
i]
                          TV Int8
-> TV Int8 -> TV Int8 -> ImpM KernelsMem KernelEnv KernelOp ()
unmakeStatusUsed TV Int8
flag2 TV Int8
flag2 TV Int8
used2
                          [(VName, VName)]
-> ((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
agg2s [VName]
exchanges) (((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
agg2, VName
exchange) ->
                            VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
agg2 [] (VName -> SubExp
Var VName
exchange) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
i]
                          TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf
                            (TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TExp Bool -> TExp Bool) -> TExp Bool -> TExp Bool
forall a b. (a -> b) -> a -> b
$ TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flag2 TExp Int8 -> TExp Int8 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int8
forall a. Num a => a
statusA)
                            ( do
                                TV Int8
flag1 TV Int8 -> TExp Int8 -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flag2
                                TV Int8
used1 TV Int8 -> TExp Int8 -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
used2
                                [(VName, PrimType, VName)]
-> ((VName, PrimType, VName)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [PrimType] -> [VName] -> [(VName, PrimType, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
agg1s [PrimType]
tys [VName]
agg2s) (((VName, PrimType, VName)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, PrimType, VName)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
agg1, PrimType
ty, VName
agg2) ->
                                  VName
agg1 VName -> Exp -> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty (VName -> SubExp
Var VName
agg2)
                            )
                            ( do
                                TV Int8
used1 TV Int8 -> TExp Int8 -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
used1 TExp Int8 -> TExp Int8 -> TExp Int8
forall a. Num a => a -> a -> a
+ TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
used2
                                Names
-> Stms KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT KernelsMem -> Stms KernelsMem
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT KernelsMem -> Stms KernelsMem)
-> BodyT KernelsMem -> Stms KernelsMem
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody LambdaT KernelsMem
scanOp'') (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                                  [(VName, PrimType, SubExp)]
-> ((VName, PrimType, SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [PrimType] -> [SubExp] -> [(VName, PrimType, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
agg1s [PrimType]
tys ([SubExp] -> [(VName, PrimType, SubExp)])
-> [SubExp] -> [(VName, PrimType, SubExp)]
forall a b. (a -> b) -> a -> b
$ BodyT KernelsMem -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (BodyT KernelsMem -> [SubExp]) -> BodyT KernelsMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody LambdaT KernelsMem
scanOp'') (((VName, PrimType, SubExp)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, PrimType, SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                                    \(VName
agg1, PrimType
ty, SubExp
res) -> VName
agg1 VName -> Exp -> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
res
                            )
                        TV Int8
flag TV Int8 -> TExp Int8 -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flag1
                        TV Int8
used TV Int8 -> TExp Int8 -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
used1
                        [(TV Any, PrimType, VName)]
-> ((TV Any, PrimType, VName)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [PrimType] -> [VName] -> [(TV Any, PrimType, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TV Any]
aggrs [PrimType]
tys [VName]
agg1s) (((TV Any, PrimType, VName)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((TV Any, PrimType, VName)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
aggr, PrimType
ty, VName
agg1) ->
                          TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
aggr VName -> Exp -> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty (VName -> SubExp
Var VName
agg1)
                    )
                    -- else
                    ( [(TV Any, VName)]
-> ((TV Any, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [VName] -> [(TV Any, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
aggrs [VName]
exchanges) (((TV Any, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((TV Any, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
aggr, VName
exchange) ->
                        VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
aggr) [] (VName -> SubExp
Var VName
exchange) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
warpSize TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
                    )
                  -- end sIf
                  TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf
                    (TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
flag TExp Int8 -> TExp Int8 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int8
forall a. Num a => a
statusP)
                    (TV Int32
readOffset TV Int32 -> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TExp Int32
loopStop)
                    (TV Int32
readOffset TV Int32 -> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
readOffset TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int8 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
zExt32 (TV Int8 -> TExp Int8
forall t. TV t -> TExp t
tvExp TV Int8
used))
                  VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
sharedReadOffset [TExp Int64
0] (TV Int32 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int32
readOffset) []
                  LambdaT KernelsMem
scanOp''' <- LambdaT KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp (LambdaT KernelsMem)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda LambdaT KernelsMem
scanOp'
                  let ([VName]
xs, [VName]
ys) = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([VName] -> ([VName], [VName])) -> [VName] -> ([VName], [VName])
forall a b. (a -> b) -> a -> b
$ (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT KernelsMem
scanOp'''
                  [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
aggrs) (((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
aggr) -> VName -> TExp Any -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ VName
x (TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
aggr)
                  [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [TV Any]
prefixes) (((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
y, TV Any
prefix) -> VName -> TExp Any -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ VName
y (TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
prefix)
                  Names
-> Stms KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT KernelsMem -> Stms KernelsMem
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT KernelsMem -> Stms KernelsMem)
-> BodyT KernelsMem -> Stms KernelsMem
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody LambdaT KernelsMem
scanOp''') (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                    [(TV Any, PrimType, SubExp)]
-> ((TV Any, PrimType, SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [PrimType] -> [SubExp] -> [(TV Any, PrimType, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TV Any]
prefixes [PrimType]
tys ([SubExp] -> [(TV Any, PrimType, SubExp)])
-> [SubExp] -> [(TV Any, PrimType, SubExp)]
forall a b. (a -> b) -> a -> b
$ BodyT KernelsMem -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (BodyT KernelsMem -> [SubExp]) -> BodyT KernelsMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody LambdaT KernelsMem
scanOp''') (((TV Any, PrimType, SubExp)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((TV Any, PrimType, SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                      \(TV Any
prefix, PrimType
ty, SubExp
res) -> TV Any
prefix TV Any -> TExp Any -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
res)
                -- end sWhen
                KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
localFence
                VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (TV Int32 -> VName
forall t. TV t -> VName
tvVar TV Int32
readOffset) [] (VName -> SubExp
Var VName
sharedReadOffset) [TExp Int64
0]
          )
        -- end sWhile
        -- end sIf
        TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
          LambdaT KernelsMem
scanOp'''' <- LambdaT KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp (LambdaT KernelsMem)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda LambdaT KernelsMem
scanOp'
          let xs :: [VName]
xs = (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ Int -> [Param LParamMem] -> [Param LParamMem]
forall a. Int -> [a] -> [a]
take ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([Param LParamMem] -> [Param LParamMem])
-> [Param LParamMem] -> [Param LParamMem]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT KernelsMem
scanOp''''
              ys :: [VName]
ys = (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ Int -> [Param LParamMem] -> [Param LParamMem]
forall a. Int -> [a] -> [a]
drop ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([Param LParamMem] -> [Param LParamMem])
-> [Param LParamMem] -> [Param LParamMem]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT KernelsMem
scanOp''''
          [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [TV Any]
prefixes) (((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
x, TV Any
prefix) -> VName -> TExp Any -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ VName
x (TExp Any -> ImpM KernelsMem KernelEnv KernelOp ())
-> TExp Any -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
prefix
          [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [TV Any]
accs) (((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
y, TV Any
acc) -> VName -> TExp Any -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ VName
y (TExp Any -> ImpM KernelsMem KernelEnv KernelOp ())
-> TExp Any -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
acc
          Names
-> Stms KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT KernelsMem -> Stms KernelsMem
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT KernelsMem -> Stms KernelsMem)
-> BodyT KernelsMem -> Stms KernelsMem
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody LambdaT KernelsMem
scanOp'''') (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
            ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op a. ImpM lore r op a -> ImpM lore r op a
everythingVolatile (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
              [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
incprefixArrays ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ BodyT KernelsMem -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (BodyT KernelsMem -> [SubExp]) -> BodyT KernelsMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody LambdaT KernelsMem
scanOp'''') (((VName, SubExp) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                \(VName
incprefixArray, SubExp
res) -> VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] SubExp
res []
          KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
globalFence
          ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op a. ImpM lore r op a -> ImpM lore r op a
everythingVolatile (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
statusFlags [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId] (IntType -> Integer -> SubExp
intConst IntType
Int8 Integer
forall a. Num a => a
statusP) []
          [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
prefixes) (((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
prefix) ->
            VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
exchange [TExp Int64
0] (TV Any -> SubExp
forall t. TV t -> SubExp
tvSize TV Any
prefix) []
          [(TV Any, PrimType, SubExp)]
-> ((TV Any, PrimType, SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TV Any] -> [PrimType] -> [SubExp] -> [(TV Any, PrimType, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TV Any]
accs [PrimType]
tys [SubExp]
scanOpNe) (((TV Any, PrimType, SubExp)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((TV Any, PrimType, SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TV Any
acc, PrimType
ty, SubExp
ne) ->
            TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
acc VName -> Exp -> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
ne
      -- end sWhen
      -- end sWhen

      TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TExp Bool -> TExp Bool) -> TExp Bool -> TExp Bool
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
        KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
localBarrier
        [(VName, TV Any)]
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [TV Any] -> [(VName, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
exchanges [TV Any]
prefixes) (((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, TV Any) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
exchange, TV Any
prefix) ->
          VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
prefix) [] (VName -> SubExp
Var VName
exchange) [TExp Int64
0]
        KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
localBarrier
    -- end sWhen
    -- end sComment

    LambdaT KernelsMem
scanOp''''' <- LambdaT KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp (LambdaT KernelsMem)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda LambdaT KernelsMem
scanOp'
    LambdaT KernelsMem
scanOp'''''' <- LambdaT KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp (LambdaT KernelsMem)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda LambdaT KernelsMem
scanOp'

    String
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Distribute results" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
      let ([VName]
xs, [VName]
ys) = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([VName] -> ([VName], [VName])) -> [VName] -> ([VName], [VName])
forall a b. (a -> b) -> a -> b
$ (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT KernelsMem
scanOp'''''
          ([VName]
xs', [VName]
ys') = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([PrimType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([VName] -> ([VName], [VName])) -> [VName] -> ([VName], [VName])
forall a b. (a -> b) -> a -> b
$ (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT KernelsMem
scanOp''''''

      [((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)]
-> (((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(TV Any, TV Any)]
-> [(VName, VName)]
-> [(VName, VName)]
-> [PrimType]
-> [((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 ([TV Any] -> [TV Any] -> [(TV Any, TV Any)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TV Any]
prefixes [TV Any]
accs) ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [VName]
xs') ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys [VName]
ys') [PrimType]
tys) ((((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (((TV Any, TV Any), (VName, VName), (VName, VName), PrimType)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
        \((TV Any
prefix, TV Any
acc), (VName
x, VName
x'), (VName
y, VName
y'), PrimType
ty) -> do
          VName -> PrimType -> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. VName -> PrimType -> ImpM lore r op ()
dPrim_ VName
x PrimType
ty
          VName -> PrimType -> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. VName -> PrimType -> ImpM lore r op ()
dPrim_ VName
y PrimType
ty
          VName -> TExp Any -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ VName
x' (TExp Any -> ImpM KernelsMem KernelEnv KernelOp ())
-> TExp Any -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
prefix
          VName -> TExp Any -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ VName
y' (TExp Any -> ImpM KernelsMem KernelEnv KernelOp ())
-> TExp Any -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV Any -> TExp Any
forall t. TV t -> TExp t
tvExp TV Any
acc

      Names
-> Stms KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT KernelsMem -> Stms KernelsMem
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT KernelsMem -> Stms KernelsMem)
-> BodyT KernelsMem -> Stms KernelsMem
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody LambdaT KernelsMem
scanOp'''''') (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
        [(VName, PrimType, SubExp)]
-> ((VName, PrimType, SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [PrimType] -> [SubExp] -> [(VName, PrimType, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
xs [PrimType]
tys ([SubExp] -> [(VName, PrimType, SubExp)])
-> [SubExp] -> [(VName, PrimType, SubExp)]
forall a b. (a -> b) -> a -> b
$ BodyT KernelsMem -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (BodyT KernelsMem -> [SubExp]) -> BodyT KernelsMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody LambdaT KernelsMem
scanOp'''''') (((VName, PrimType, SubExp)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, PrimType, SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
          \(VName
x, PrimType
ty, SubExp
res) -> VName
x VName -> Exp -> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
ty SubExp
res

      String
-> TExp Int64
-> (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" TExp Int64
forall a. Num a => a
m ((TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
        [(VName, VName)]
-> ((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays [VName]
ys) (((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
src, VName
y) ->
          VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
y [] (VName -> SubExp
Var VName
src) [TExp Int64
i]

        Names
-> Stms KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT KernelsMem -> Stms KernelsMem
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT KernelsMem -> Stms KernelsMem)
-> BodyT KernelsMem -> Stms KernelsMem
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody LambdaT KernelsMem
scanOp''''') (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
          [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
privateArrays ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ BodyT KernelsMem -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (BodyT KernelsMem -> [SubExp]) -> BodyT KernelsMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody LambdaT KernelsMem
scanOp''''') (((VName, SubExp) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
            \(VName
dest, SubExp
res) ->
              VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
res []

    String
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Transpose scan output" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
      [(VName, VName)]
-> ((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
transposedArrays [VName]
privateArrays) (((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
trans, VName
priv) -> do
        KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
localBarrier
        String
-> TExp Int64
-> (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" TExp Int64
forall a. Num a => a
m ((TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
          TV Int64
sharedIdx <-
            String
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"sharedIdx" (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
              TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
forall a. Num a => a
m) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i
          VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
trans [TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
sharedIdx] (VName -> SubExp
Var VName
priv) [TExp Int64
i]
        KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
localBarrier
        String
-> TExp Int64
-> (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" TExp Int64
forall a. Num a => a
m ((TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
          TV Int32
sharedIdx <-
            String
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TV Int32)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"sharedIdx" (TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TV Int32))
-> TExp Int32 -> ImpM KernelsMem KernelEnv KernelOp (TV Int32)
forall a b. (a -> b) -> a -> b
$
              KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
                TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
i)
          VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
priv [TExp Int64
i] (VName -> SubExp
Var VName
trans) [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
sharedIdx]
      KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp KernelOp
localBarrier

    String
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Write block scan results to global memory" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
      [(VName, VName)]
-> ((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip ((PatElemT LParamMem -> VName) -> [PatElemT LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName [PatElemT LParamMem]
all_pes) [VName]
privateArrays) (((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, VName
src) ->
        String
-> TExp Int64
-> (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" TExp Int64
forall a. Num a => a
m ((TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
          VName -> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ VName
mapIdx (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ())
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
            TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
blockOff TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
i
              TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants)
          TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (VName -> TExp Int64
Imp.vi64 VName
mapIdx TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
n) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
            VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
dest [VName -> TExp Int64
Imp.vi64 VName
mapIdx] (VName -> SubExp
Var VName
src) [TExp Int64
i]

    String
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"If this is the last block, reset the dynamicId" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
      TExp Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
dynamicId TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
        VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
globalId [TExp Int64
0] (Int32 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int32
0 :: Int32)) []