{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
-- | Code generation for segmented and non-segmented scans.  Uses a
-- fairly inefficient two-pass algorithm.
module Futhark.CodeGen.ImpGen.Kernels.SegScan
  ( compileSegScan )
  where

import Control.Monad.Except
import Control.Monad.State
import Data.Maybe
import Data.List (delete, find, foldl', zip4)

import Prelude hiding (quot, rem)

import Futhark.Transform.Rename
import Futhark.IR.KernelsMem
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.Base
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Futhark.Util (takeLast)

-- | Create the intermediate arrays used by each scan operator.  The
-- result contains one list per 'SegBinOp', holding one array per "x"
-- parameter of that operator's lambda (the first @length nes@
-- parameters).
--
-- Aggressively try to reuse memory for different SegBinOps, because
-- we will run them sequentially, one after another.
--
-- * A parameter that is itself an array in memory gets an array of
--   shape @num_threads : shape@, placed with an 'IxFun.iota' index
--   function in the memory block the parameter already refers to.
--   Nothing is declared or allocated for it here.
--
-- * Any other parameter gets a @group_size@-element array in a
--   memory block obtained from a pool that is threaded through all
--   operators as 'StateT' state.  Each pool entry pairs a memory
--   block with every size that has been requested of it.
--
-- Once all operators have been processed, every block in the pool is
-- allocated in @Space "local"@ with the maximum of its recorded
-- sizes.
makeLocalArrays :: Count GroupSize SubExp -> SubExp -> [SegBinOp KernelsMem]
                -> InKernelGen [[VName]]
makeLocalArrays (Count group_size) num_threads scans = do
  (arrs, mems_and_sizes) <- runStateT (mapM onScan scans) mempty
  -- A block must be able to hold the largest array ever placed in it.
  -- The fold is seeded with 1, so the result is never below that.
  let maxSize sizes =
        Imp.bytes $ foldl' (Imp.BinOpExp (SMax Int32)) 1 $
        map Imp.unCount sizes
  forM_ mems_and_sizes $ \(sizes, mem) ->
    sAlloc_ mem (maxSize sizes) (Space "local")
  return arrs

  where
    -- Produce the arrays for a single operator.  Blocks handed out by
    -- 'getMem' are removed from the pool while this operator's
    -- parameters are processed (so two parameters of the same
    -- operator never share a block) and are put back afterwards so
    -- that the next operator can reuse them.
    onScan (SegBinOp _ scan_op nes _) = do
      let scan_x_params = take (length nes) $ lambdaParams scan_op
      (arrs, used_mems) <- fmap unzip $ forM scan_x_params $ \p ->
        case paramDec p of
          MemArray pt shape _ (ArrayIn mem _) -> do
            let shape' = Shape [num_threads] <> shape
            arr <- lift $ sArray "scan_arr" pt shape' $
              ArrayIn mem $ IxFun.iota $
              map (primExpFromSubExp int32) $ shapeDims shape'
            return (arr, [])
          _ -> do
            let pt = elemType $ paramType p
                shape = Shape [group_size]
            (sizes, mem') <- getMem pt shape
            arr <- lift $ sArrayInMem "scan_arr" pt shape mem'
            return (arr, [(sizes, mem')])
      modify (<> concat used_mems)
      return arrs

    -- Take a memory block for an array of the given element type and
    -- shape out of the pool.  In order of preference:
    --
    --   1. a block for which exactly this size is already recorded;
    --   2. the first block in the pool, with the new size added to
    --      its list of sizes;
    --   3. a freshly declared block (when the pool is empty).
    getMem pt shape = do
      let size = typeSize $ Array pt shape NoUniqueness
      mems <- get
      case (find ((size `elem`) . fst) mems, mems) of
        (Just mem, _) -> do
          modify $ delete mem
          return mem
        (Nothing, (sizes, mem) : mems') -> do
          put mems'
          return (size : sizes, mem)
        (Nothing, []) -> do
          mem <- lift $ sDeclareMem "scan_arr_mem" $ Space "local"
          return ([size], mem)

-- | An optional function from two 'Imp.Exp's to an 'Imp.Exp'.
--
-- NOTE(review): presumably a predicate on two indices that tells
-- whether they lie in different segments, with 'Nothing' meaning the
-- scan is not segmented -- confirm against the users of this type.
type CrossesSegment = Maybe (Imp.Exp -> Imp.Exp -> Imp.Exp)

-- | The thread index to use for a value of the given type: the local
-- thread ID when the type is primitive, and the global thread ID
-- otherwise.
localArrayIndex :: KernelConstants -> Type -> Imp.Exp
localArrayIndex constants t
  | primType t = kernelLocalThreadId constants
  | otherwise = kernelGlobalThreadId constants

-- | Describe the barrier to use with the given scan operator.  The
-- result is a triple of:
--
-- * whether any value returned by the operator is non-primitive;
--
-- * the fence to use: 'Imp.FenceGlobal' if so, 'Imp.FenceLocal'
--   otherwise;
--
-- * an action that emits an 'Imp.Barrier' with that fence.
barrierFor :: Lambda KernelsMem -> (Bool, Imp.Fence, InKernelGen ())
barrierFor scan_op = (array_scan, fence, sOp $ Imp.Barrier fence)
  where
    array_scan = not $ all primType $ lambdaReturnType scan_op
    fence
      | array_scan = Imp.FenceGlobal
      | otherwise = Imp.FenceLocal

-- | The parameters of a scan operator's lambda, split at the number
-- of neutral elements: 'xParams' is the leading @length
-- (segBinOpNeutral scan)@ parameters and 'yParams' is the rest.
xParams, yParams :: SegBinOp KernelsMem -> [LParam KernelsMem]
xParams scan =
  take (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
yParams scan =
  drop (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))

-- | Store the kernel body results belonging to one scan operator.
--
-- If the operator has a shape of non-zero rank, each result is
-- copied to the corresponding pattern element at the index given by
-- @gtids@.  Otherwise each result is copied into the corresponding
-- "y" parameter of the operator (see 'yParams').
writeToScanValues :: [VName]
                  -> ([PatElem KernelsMem], SegBinOp KernelsMem, [KernelResult])
                  -> InKernelGen ()
writeToScanValues gtids (pes, scan, scan_res)
  | shapeRank (segBinOpShape scan) > 0 =
      forM_ (zip pes scan_res) $ \(pe, res) ->
        copyDWIMFix (patElemName pe) (map Imp.vi32 gtids)
        (kernelResultSubExp res) []
  | otherwise =
      forM_ (zip (yParams scan) scan_res) $ \(p, res) ->
        copyDWIMFix (paramName p) [] (kernelResultSubExp res) []

-- | For an operator whose shape has non-zero rank, load the "y"
-- parameters of the operator (see 'yParams') from the pattern
-- elements @pes@ at index @is@.  For any other operator this does
-- nothing.
readToScanValues :: [Imp.Exp] -> [PatElem KernelsMem] -> SegBinOp KernelsMem
                 -> InKernelGen ()
readToScanValues is pes scan
  | shapeRank (segBinOpShape scan) > 0 =
      forM_ (zip (yParams scan) pes) $ \(p, pe) ->
        copyDWIMFix (paramName p) [] (Var (patElemName pe)) is
  | otherwise =
      return ()

-- | For an operator whose shape has non-zero rank, initialise the "x"
-- parameters of the operator (see 'xParams'); for any other operator
-- this does nothing.
--
-- When @chunk_offset@ is positive, the thread whose local ID is 0
-- reads the parameters from the pattern elements @pes@ at the index
-- obtained by unflattening @chunk_offset - 1@ with respect to
-- @dims'@, followed by @vec_is@.  In every other case the parameters
-- are set to the neutral elements of the operator.
readCarries :: Imp.Exp -> [Imp.Exp] -> [Imp.Exp]
            -> [PatElem KernelsMem]
            -> SegBinOp KernelsMem
            -> InKernelGen ()
readCarries chunk_offset dims' vec_is pes scan
  | shapeRank (segBinOpShape scan) > 0 = do
      ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
      -- We may have to reload the carries from the output of the
      -- previous chunk.
      sIf (chunk_offset .>. 0 .&&. ltid .==. 0)
        (do let is = unflattenIndex dims' $ chunk_offset - 1
            forM_ (zip (xParams scan) pes) $ \(p, pe) ->
              copyDWIMFix (paramName p) [] (Var (patElemName pe)) (is ++ vec_is))
        (forM_ (zip (xParams scan) (segBinOpNeutral scan)) $ \(p, ne) ->
            copyDWIMFix (paramName p) [] ne [])
  | otherwise =
      return ()

-- | Produce partially scanned intervals; one per workgroup.
scanStage1 :: Pattern KernelsMem
           -> Count NumGroups SubExp -> Count GroupSize SubExp -> SegSpace
           -> [SegBinOp KernelsMem]
           -> KernelBody KernelsMem
           -> CallKernelGen (VName, Imp.Exp, CrossesSegment)
scanStage1 :: Pattern KernelsMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp KernelsMem]
-> KernelBody KernelsMem
-> CallKernelGen (VName, PrimExp ExpLeaf, CrossesSegment)
scanStage1 (Pattern [PatElem KernelsMem]
_ [PatElem KernelsMem]
all_pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegBinOp KernelsMem]
scans KernelBody KernelsMem
kbody = do
  Count NumGroups (PrimExp ExpLeaf)
num_groups' <- (SubExp -> ImpM KernelsMem HostEnv HostOp (PrimExp ExpLeaf))
-> Count NumGroups SubExp
-> ImpM
     KernelsMem HostEnv HostOp (Count NumGroups (PrimExp ExpLeaf))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM KernelsMem HostEnv HostOp (PrimExp ExpLeaf)
forall a lore r op.
ToExp a =>
a -> ImpM lore r op (PrimExp ExpLeaf)
toExp Count NumGroups SubExp
num_groups
  Count GroupSize (PrimExp ExpLeaf)
group_size' <- (SubExp -> ImpM KernelsMem HostEnv HostOp (PrimExp ExpLeaf))
-> Count GroupSize SubExp
-> ImpM
     KernelsMem HostEnv HostOp (Count GroupSize (PrimExp ExpLeaf))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM KernelsMem HostEnv HostOp (PrimExp ExpLeaf)
forall a lore r op.
ToExp a =>
a -> ImpM lore r op (PrimExp ExpLeaf)
toExp Count GroupSize SubExp
group_size
  VName
num_threads <- SpaceId -> PrimExp ExpLeaf -> ImpM KernelsMem HostEnv HostOp VName
forall lore r op.
SpaceId -> PrimExp ExpLeaf -> ImpM lore r op VName
dPrimV SpaceId
"num_threads" (PrimExp ExpLeaf -> ImpM KernelsMem HostEnv HostOp VName)
-> PrimExp ExpLeaf -> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
                 Count NumGroups (PrimExp ExpLeaf) -> PrimExp ExpLeaf
forall u e. Count u e -> e
unCount Count NumGroups (PrimExp ExpLeaf)
num_groups' PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall a. Num a => a -> a -> a
* Count GroupSize (PrimExp ExpLeaf) -> PrimExp ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (PrimExp ExpLeaf)
group_size'

  let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
  [PrimExp ExpLeaf]
dims' <- (SubExp -> ImpM KernelsMem HostEnv HostOp (PrimExp ExpLeaf))
-> [SubExp] -> ImpM KernelsMem HostEnv HostOp [PrimExp ExpLeaf]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> ImpM KernelsMem HostEnv HostOp (PrimExp ExpLeaf)
forall a lore r op.
ToExp a =>
a -> ImpM lore r op (PrimExp ExpLeaf)
toExp [SubExp]
dims
  let num_elements :: PrimExp ExpLeaf
num_elements = [PrimExp ExpLeaf] -> PrimExp ExpLeaf
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [PrimExp ExpLeaf]
dims'
      elems_per_thread :: PrimExp ExpLeaf
elems_per_thread = PrimExp ExpLeaf
num_elements PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall e. IntegralExp e => e -> e -> e
`divUp` VName -> PrimExp ExpLeaf
Imp.vi32 VName
num_threads
      elems_per_group :: PrimExp ExpLeaf
elems_per_group = Count GroupSize (PrimExp ExpLeaf) -> PrimExp ExpLeaf
forall u e. Count u e -> e
unCount Count GroupSize (PrimExp ExpLeaf)
group_size' PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall a. Num a => a -> a -> a
* PrimExp ExpLeaf
elems_per_thread

  let crossesSegment :: CrossesSegment
crossesSegment =
        case [PrimExp ExpLeaf] -> [PrimExp ExpLeaf]
forall a. [a] -> [a]
reverse [PrimExp ExpLeaf]
dims' of
          PrimExp ExpLeaf
segment_size : PrimExp ExpLeaf
_ : [PrimExp ExpLeaf]
_ -> (PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf)
-> CrossesSegment
forall a. a -> Maybe a
Just ((PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf)
 -> CrossesSegment)
-> (PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf)
-> CrossesSegment
forall a b. (a -> b) -> a -> b
$ \PrimExp ExpLeaf
from PrimExp ExpLeaf
to ->
            (PrimExp ExpLeaf
toPrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall a. Num a => a -> a -> a
-PrimExp ExpLeaf
from) PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall v. PrimExp v -> PrimExp v -> PrimExp v
.>. (PrimExp ExpLeaf
to PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall e. IntegralExp e => e -> e -> e
`rem` PrimExp ExpLeaf
segment_size)
          [PrimExp ExpLeaf]
_ -> CrossesSegment
forall a. Maybe a
Nothing

  SpaceId
-> Count NumGroups (PrimExp ExpLeaf)
-> Count GroupSize (PrimExp ExpLeaf)
-> VName
-> ImpM KernelsMem KernelEnv KernelOp ()
-> CallKernelGen ()
sKernelThread SpaceId
"scan_stage1" Count NumGroups (PrimExp ExpLeaf)
num_groups' Count GroupSize (PrimExp ExpLeaf)
group_size' (SegSpace -> VName
segFlat SegSpace
space) (ImpM KernelsMem KernelEnv KernelOp () -> CallKernelGen ())
-> ImpM KernelsMem KernelEnv KernelOp () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
    [[VName]]
all_local_arrs <- Count GroupSize SubExp
-> SubExp -> [SegBinOp KernelsMem] -> InKernelGen [[VName]]
makeLocalArrays Count GroupSize SubExp
group_size (VName -> SubExp
Var VName
num_threads) [SegBinOp KernelsMem]
scans

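    -- Declare the parameters of every scan operator once, before the
    -- big chunking loop, and initialise the x-parameters to the neutral
    -- elements.  These parameters are then used for the carry (and the
    -- operands) inside the loop.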
    [SegBinOp KernelsMem]
-> (SegBinOp KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [SegBinOp KernelsMem]
scans ((SegBinOp KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (SegBinOp KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \SegBinOp KernelsMem
scan -> do
      Maybe (Exp KernelsMem)
-> Scope KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Mem lore =>
Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
dScope Maybe (Exp KernelsMem)
forall a. Maybe a
Nothing (Scope KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ())
-> Scope KernelsMem -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ [Param (MemInfo SubExp NoUniqueness MemBind)] -> Scope KernelsMem
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams ([Param (MemInfo SubExp NoUniqueness MemBind)] -> Scope KernelsMem)
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> Scope KernelsMem
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams (Lambda KernelsMem -> [LParam KernelsMem])
-> Lambda KernelsMem -> [LParam KernelsMem]
forall a b. (a -> b) -> a -> b
$ SegBinOp KernelsMem -> Lambda KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda SegBinOp KernelsMem
scan
      [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [SubExp]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp KernelsMem -> [LParam KernelsMem]
xParams SegBinOp KernelsMem
scan) (SegBinOp KernelsMem -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral SegBinOp KernelsMem
scan)) (((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, SubExp
ne) ->
        VName
-> [PrimExp ExpLeaf]
-> SubExp
-> [PrimExp ExpLeaf]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [PrimExp ExpLeaf]
-> SubExp
-> [PrimExp ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] SubExp
ne []

    SpaceId
-> PrimExp ExpLeaf
-> (PrimExp ExpLeaf -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
SpaceId
-> PrimExp ExpLeaf
-> (PrimExp ExpLeaf -> ImpM lore r op ())
-> ImpM lore r op ()
sFor SpaceId
"j" PrimExp ExpLeaf
elems_per_thread ((PrimExp ExpLeaf -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (PrimExp ExpLeaf -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \PrimExp ExpLeaf
j -> do
      VName
chunk_offset <- SpaceId
-> PrimExp ExpLeaf -> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op.
SpaceId -> PrimExp ExpLeaf -> ImpM lore r op VName
dPrimV SpaceId
"chunk_offset" (PrimExp ExpLeaf -> ImpM KernelsMem KernelEnv KernelOp VName)
-> PrimExp ExpLeaf -> ImpM KernelsMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$
                      KernelConstants -> PrimExp ExpLeaf
kernelGroupSize KernelConstants
constants PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall a. Num a => a -> a -> a
* PrimExp ExpLeaf
j PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall a. Num a => a -> a -> a
+
                      KernelConstants -> PrimExp ExpLeaf
kernelGroupId KernelConstants
constants PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall a. Num a => a -> a -> a
* PrimExp ExpLeaf
elems_per_group
      VName
flat_idx <- SpaceId
-> PrimExp ExpLeaf -> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op.
SpaceId -> PrimExp ExpLeaf -> ImpM lore r op VName
dPrimV SpaceId
"flat_idx" (PrimExp ExpLeaf -> ImpM KernelsMem KernelEnv KernelOp VName)
-> PrimExp ExpLeaf -> ImpM KernelsMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$
                  VName -> PrimType -> PrimExp ExpLeaf
Imp.var VName
chunk_offset PrimType
int32 PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall a. Num a => a -> a -> a
+ KernelConstants -> PrimExp ExpLeaf
kernelLocalThreadId KernelConstants
constants
      -- Construct segment indices.
      (VName -> PrimExp ExpLeaf -> ImpM KernelsMem KernelEnv KernelOp ())
-> [VName]
-> [PrimExp ExpLeaf]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> PrimExp ExpLeaf -> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. VName -> PrimExp ExpLeaf -> ImpM lore r op ()
dPrimV_ [VName]
gtids ([PrimExp ExpLeaf] -> ImpM KernelsMem KernelEnv KernelOp ())
-> [PrimExp ExpLeaf] -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ [PrimExp ExpLeaf] -> PrimExp ExpLeaf -> [PrimExp ExpLeaf]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [PrimExp ExpLeaf]
dims' (PrimExp ExpLeaf -> [PrimExp ExpLeaf])
-> PrimExp ExpLeaf -> [PrimExp ExpLeaf]
forall a b. (a -> b) -> a -> b
$ VName -> PrimExp ExpLeaf
Imp.vi32 VName
flat_idx

      let per_scan_pes :: [[PatElemT (MemInfo SubExp NoUniqueness MemBind)]]
per_scan_pes = [SegBinOp KernelsMem]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [[PatElemT (MemInfo SubExp NoUniqueness MemBind)]]
forall lore a. [SegBinOp lore] -> [a] -> [[a]]
segBinOpChunks [SegBinOp KernelsMem]
scans [PatElem KernelsMem]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
all_pes

          in_bounds :: PrimExp ExpLeaf
in_bounds =
            (PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf)
-> [PrimExp ExpLeaf] -> PrimExp ExpLeaf
forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldl1 PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall v. PrimExp v -> PrimExp v -> PrimExp v
(.&&.) ([PrimExp ExpLeaf] -> PrimExp ExpLeaf)
-> [PrimExp ExpLeaf] -> PrimExp ExpLeaf
forall a b. (a -> b) -> a -> b
$ (PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf)
-> [PrimExp ExpLeaf] -> [PrimExp ExpLeaf] -> [PrimExp ExpLeaf]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall v. PrimExp v -> PrimExp v -> PrimExp v
(.<.) ((VName -> PrimExp ExpLeaf) -> [VName] -> [PrimExp ExpLeaf]
forall a b. (a -> b) -> [a] -> [b]
map VName -> PrimExp ExpLeaf
Imp.vi32 [VName]
gtids) [PrimExp ExpLeaf]
dims'

          when_in_bounds :: ImpM KernelsMem KernelEnv KernelOp ()
when_in_bounds = Names
-> Stms KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody KernelsMem -> Stms KernelsMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody KernelsMem
kbody) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
            let ([KernelResult]
all_scan_res, [KernelResult]
map_res) =
                  Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp KernelsMem] -> Int
forall lore. [SegBinOp lore] -> Int
segBinOpResults [SegBinOp KernelsMem]
scans) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody KernelsMem -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody KernelsMem
kbody
                per_scan_res :: [[KernelResult]]
per_scan_res =
                  [SegBinOp KernelsMem] -> [KernelResult] -> [[KernelResult]]
forall lore a. [SegBinOp lore] -> [a] -> [[a]]
segBinOpChunks [SegBinOp KernelsMem]
scans [KernelResult]
all_scan_res

            SpaceId
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. SpaceId -> ImpM lore r op () -> ImpM lore r op ()
sComment SpaceId
"write to-scan values to parameters" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
              (([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
  SegBinOp KernelsMem, [KernelResult])
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> [([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
     SegBinOp KernelsMem, [KernelResult])]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ([VName]
-> ([PatElem KernelsMem], SegBinOp KernelsMem, [KernelResult])
-> ImpM KernelsMem KernelEnv KernelOp ()
writeToScanValues [VName]
gtids) ([([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
   SegBinOp KernelsMem, [KernelResult])]
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> [([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
     SegBinOp KernelsMem, [KernelResult])]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
              [[PatElemT (MemInfo SubExp NoUniqueness MemBind)]]
-> [SegBinOp KernelsMem]
-> [[KernelResult]]
-> [([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
     SegBinOp KernelsMem, [KernelResult])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElemT (MemInfo SubExp NoUniqueness MemBind)]]
per_scan_pes [SegBinOp KernelsMem]
scans [[KernelResult]]
per_scan_res

            SpaceId
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. SpaceId -> ImpM lore r op () -> ImpM lore r op ()
sComment SpaceId
"write mapped values results to global memory" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
              [(PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
-> ((PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [KernelResult]
-> [(PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
takeLast ([KernelResult] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
map_res) [PatElem KernelsMem]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
all_pes) [KernelResult]
map_res) (((PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe, KernelResult
se) ->
              VName
-> [PrimExp ExpLeaf]
-> SubExp
-> [PrimExp ExpLeaf]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [PrimExp ExpLeaf]
-> SubExp
-> [PrimExp ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix (PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ((VName -> PrimExp ExpLeaf) -> [VName] -> [PrimExp ExpLeaf]
forall a b. (a -> b) -> [a] -> [b]
map VName -> PrimExp ExpLeaf
Imp.vi32 [VName]
gtids)
              (KernelResult -> SubExp
kernelResultSubExp KernelResult
se) []

      SpaceId
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. SpaceId -> ImpM lore r op () -> ImpM lore r op ()
sComment SpaceId
"threads in bounds read input" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
        PrimExp ExpLeaf
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
PrimExp ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen PrimExp ExpLeaf
in_bounds ImpM KernelsMem KernelEnv KernelOp ()
when_in_bounds

      [([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
  SegBinOp KernelsMem, [VName])]
-> (([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
     SegBinOp KernelsMem, [VName])
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([[PatElemT (MemInfo SubExp NoUniqueness MemBind)]]
-> [SegBinOp KernelsMem]
-> [[VName]]
-> [([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
     SegBinOp KernelsMem, [VName])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElemT (MemInfo SubExp NoUniqueness MemBind)]]
per_scan_pes [SegBinOp KernelsMem]
scans [[VName]]
all_local_arrs) ((([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
   SegBinOp KernelsMem, [VName])
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> (([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
     SegBinOp KernelsMem, [VName])
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
        \([PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes, scan :: SegBinOp KernelsMem
scan@(SegBinOp Commutativity
_ Lambda KernelsMem
scan_op [SubExp]
nes Shape
vec_shape), [VName]
local_arrs) ->
        SpaceId
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. SpaceId -> ImpM lore r op () -> ImpM lore r op ()
sComment SpaceId
"do one intra-group scan operation" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
        let rets :: [TypeBase Shape NoUniqueness]
rets = Lambda KernelsMem -> [TypeBase Shape NoUniqueness]
forall lore. LambdaT lore -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda KernelsMem
scan_op
            scan_x_params :: [LParam KernelsMem]
scan_x_params = SegBinOp KernelsMem -> [LParam KernelsMem]
xParams SegBinOp KernelsMem
scan
            (Bool
array_scan, Fence
fence, ImpM KernelsMem KernelEnv KernelOp ()
barrier) = Lambda KernelsMem
-> (Bool, Fence, ImpM KernelsMem KernelEnv KernelOp ())
barrierFor Lambda KernelsMem
scan_op

        Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan ImpM KernelsMem KernelEnv KernelOp ()
barrier

        Shape
-> ([PrimExp ExpLeaf] -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Shape
-> ([PrimExp ExpLeaf] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest Shape
vec_shape (([PrimExp ExpLeaf] -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ([PrimExp ExpLeaf] -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \[PrimExp ExpLeaf]
vec_is -> do
          SpaceId
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. SpaceId -> ImpM lore r op () -> ImpM lore r op ()
sComment SpaceId
"maybe restore some to-scan values to parameters, or read neutral" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
            PrimExp ExpLeaf
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
PrimExp ExpLeaf
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf PrimExp ExpLeaf
in_bounds
            (do [PrimExp ExpLeaf]
-> [PatElem KernelsMem]
-> SegBinOp KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp ()
readToScanValues ((VName -> PrimExp ExpLeaf) -> [VName] -> [PrimExp ExpLeaf]
forall a b. (a -> b) -> [a] -> [b]
map VName -> PrimExp ExpLeaf
Imp.vi32 [VName]
gtids[PrimExp ExpLeaf] -> [PrimExp ExpLeaf] -> [PrimExp ExpLeaf]
forall a. [a] -> [a] -> [a]
++[PrimExp ExpLeaf]
vec_is) [PatElem KernelsMem]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes SegBinOp KernelsMem
scan
                PrimExp ExpLeaf
-> [PrimExp ExpLeaf]
-> [PrimExp ExpLeaf]
-> [PatElem KernelsMem]
-> SegBinOp KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp ()
readCarries (VName -> PrimExp ExpLeaf
Imp.vi32 VName
chunk_offset) [PrimExp ExpLeaf]
dims' [PrimExp ExpLeaf]
vec_is [PatElem KernelsMem]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes SegBinOp KernelsMem
scan)
            ([(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [SubExp]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp KernelsMem -> [LParam KernelsMem]
yParams SegBinOp KernelsMem
scan) (SegBinOp KernelsMem -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral SegBinOp KernelsMem
scan)) (((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, SubExp
ne) ->
                VName
-> [PrimExp ExpLeaf]
-> SubExp
-> [PrimExp ExpLeaf]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [PrimExp ExpLeaf]
-> SubExp
-> [PrimExp ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] SubExp
ne [])

          SpaceId
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. SpaceId -> ImpM lore r op () -> ImpM lore r op ()
sComment SpaceId
"combine with carry and write to local memory" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
            Names
-> Stms KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT KernelsMem -> Stms KernelsMem
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT KernelsMem -> Stms KernelsMem)
-> BodyT KernelsMem -> Stms KernelsMem
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda KernelsMem
scan_op) (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
            [(TypeBase Shape NoUniqueness, VName, SubExp)]
-> ((TypeBase Shape NoUniqueness, VName, SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TypeBase Shape NoUniqueness]
-> [VName]
-> [SubExp]
-> [(TypeBase Shape NoUniqueness, VName, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TypeBase Shape NoUniqueness]
rets [VName]
local_arrs (BodyT KernelsMem -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (BodyT KernelsMem -> [SubExp]) -> BodyT KernelsMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda KernelsMem
scan_op)) (((TypeBase Shape NoUniqueness, VName, SubExp)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((TypeBase Shape NoUniqueness, VName, SubExp)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
            \(TypeBase Shape NoUniqueness
t, VName
arr, SubExp
se) -> VName
-> [PrimExp ExpLeaf]
-> SubExp
-> [PrimExp ExpLeaf]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [PrimExp ExpLeaf]
-> SubExp
-> [PrimExp ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix VName
arr [KernelConstants -> TypeBase Shape NoUniqueness -> PrimExp ExpLeaf
localArrayIndex KernelConstants
constants TypeBase Shape NoUniqueness
t] SubExp
se []

          let crossesSegment' :: CrossesSegment
crossesSegment' = do
                PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
f <- CrossesSegment
crossesSegment
                (PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf)
-> CrossesSegment
forall a. a -> Maybe a
Just ((PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf)
 -> CrossesSegment)
-> (PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf)
-> CrossesSegment
forall a b. (a -> b) -> a -> b
$ \PrimExp ExpLeaf
from PrimExp ExpLeaf
to ->
                  let from' :: PrimExp ExpLeaf
from' = PrimExp ExpLeaf
from PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall a. Num a => a -> a -> a
+ VName -> PrimType -> PrimExp ExpLeaf
Imp.var VName
chunk_offset PrimType
int32
                      to' :: PrimExp ExpLeaf
to' = PrimExp ExpLeaf
to PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall a. Num a => a -> a -> a
+ VName -> PrimType -> PrimExp ExpLeaf
Imp.var VName
chunk_offset PrimType
int32
                  in PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
f PrimExp ExpLeaf
from' PrimExp ExpLeaf
to'

          KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> ImpM KernelsMem KernelEnv KernelOp ())
-> KernelOp -> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
fence

          -- We need to avoid parameter name clashes.
          Lambda KernelsMem
scan_op_renamed <- Lambda KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp (Lambda KernelsMem)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda Lambda KernelsMem
scan_op
          CrossesSegment
-> PrimExp ExpLeaf
-> PrimExp ExpLeaf
-> Lambda KernelsMem
-> [VName]
-> ImpM KernelsMem KernelEnv KernelOp ()
groupScan CrossesSegment
crossesSegment'
            (VName -> PrimExp ExpLeaf
Imp.vi32 VName
num_threads)
            (KernelConstants -> PrimExp ExpLeaf
kernelGroupSize KernelConstants
constants) Lambda KernelsMem
scan_op_renamed [VName]
local_arrs

          SpaceId
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. SpaceId -> ImpM lore r op () -> ImpM lore r op ()
sComment SpaceId
"threads in bounds write partial scan result" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
            PrimExp ExpLeaf
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
PrimExp ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen PrimExp ExpLeaf
in_bounds (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ [(TypeBase Shape NoUniqueness,
  PatElemT (MemInfo SubExp NoUniqueness MemBind), VName)]
-> ((TypeBase Shape NoUniqueness,
     PatElemT (MemInfo SubExp NoUniqueness MemBind), VName)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TypeBase Shape NoUniqueness]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> [(TypeBase Shape NoUniqueness,
     PatElemT (MemInfo SubExp NoUniqueness MemBind), VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TypeBase Shape NoUniqueness]
rets [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
pes [VName]
local_arrs) (((TypeBase Shape NoUniqueness,
   PatElemT (MemInfo SubExp NoUniqueness MemBind), VName)
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((TypeBase Shape NoUniqueness,
     PatElemT (MemInfo SubExp NoUniqueness MemBind), VName)
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TypeBase Shape NoUniqueness
t, PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe, VName
arr) ->
            VName
-> [PrimExp ExpLeaf]
-> SubExp
-> [PrimExp ExpLeaf]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [PrimExp ExpLeaf]
-> SubExp
-> [PrimExp ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix (PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe) ((VName -> PrimExp ExpLeaf) -> [VName] -> [PrimExp ExpLeaf]
forall a b. (a -> b) -> [a] -> [b]
map VName -> PrimExp ExpLeaf
Imp.vi32 [VName]
gtids[PrimExp ExpLeaf] -> [PrimExp ExpLeaf] -> [PrimExp ExpLeaf]
forall a. [a] -> [a] -> [a]
++[PrimExp ExpLeaf]
vec_is)
            (VName -> SubExp
Var VName
arr) [KernelConstants -> TypeBase Shape NoUniqueness -> PrimExp ExpLeaf
localArrayIndex KernelConstants
constants TypeBase Shape NoUniqueness
t]

          ImpM KernelsMem KernelEnv KernelOp ()
barrier

          let load_carry :: ImpM KernelsMem KernelEnv KernelOp ()
load_carry =
                [(VName, Param (MemInfo SubExp NoUniqueness MemBind))]
-> ((VName, Param (MemInfo SubExp NoUniqueness MemBind))
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [(VName, Param (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
local_arrs [LParam KernelsMem]
[Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params) (((VName, Param (MemInfo SubExp NoUniqueness MemBind))
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((VName, Param (MemInfo SubExp NoUniqueness MemBind))
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, Param (MemInfo SubExp NoUniqueness MemBind)
p) ->
                VName
-> [PrimExp ExpLeaf]
-> SubExp
-> [PrimExp ExpLeaf]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [PrimExp ExpLeaf]
-> SubExp
-> [PrimExp ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var VName
arr)
                [if TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind)
-> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param (MemInfo SubExp NoUniqueness MemBind)
p
                 then KernelConstants -> PrimExp ExpLeaf
kernelGroupSize KernelConstants
constants PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall a. Num a => a -> a -> a
- PrimExp ExpLeaf
1
                 else (KernelConstants -> PrimExp ExpLeaf
kernelGroupId KernelConstants
constantsPrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall a. Num a => a -> a -> a
+PrimExp ExpLeaf
1) PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall a. Num a => a -> a -> a
* KernelConstants -> PrimExp ExpLeaf
kernelGroupSize KernelConstants
constants PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall a. Num a => a -> a -> a
- PrimExp ExpLeaf
1]
              load_neutral :: ImpM KernelsMem KernelEnv KernelOp ()
load_neutral =
                [(SubExp, Param (MemInfo SubExp NoUniqueness MemBind))]
-> ((SubExp, Param (MemInfo SubExp NoUniqueness MemBind))
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SubExp]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [(SubExp, Param (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
nes [LParam KernelsMem]
[Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params) (((SubExp, Param (MemInfo SubExp NoUniqueness MemBind))
  -> ImpM KernelsMem KernelEnv KernelOp ())
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ((SubExp, Param (MemInfo SubExp NoUniqueness MemBind))
    -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, Param (MemInfo SubExp NoUniqueness MemBind)
p) ->
                VName
-> [PrimExp ExpLeaf]
-> SubExp
-> [PrimExp ExpLeaf]
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
VName
-> [PrimExp ExpLeaf]
-> SubExp
-> [PrimExp ExpLeaf]
-> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] SubExp
ne []

          SpaceId
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op. SpaceId -> ImpM lore r op () -> ImpM lore r op ()
sComment SpaceId
"first thread reads last element as carry-in for next iteration" (ImpM KernelsMem KernelEnv KernelOp ()
 -> ImpM KernelsMem KernelEnv KernelOp ())
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
            PrimExp ExpLeaf
crosses_segment <- SpaceId
-> PrimExp ExpLeaf
-> ImpM KernelsMem KernelEnv KernelOp (PrimExp ExpLeaf)
forall lore r op.
SpaceId -> PrimExp ExpLeaf -> ImpM lore r op (PrimExp ExpLeaf)
dPrimVE SpaceId
"crosses_segment" (PrimExp ExpLeaf
 -> ImpM KernelsMem KernelEnv KernelOp (PrimExp ExpLeaf))
-> PrimExp ExpLeaf
-> ImpM KernelsMem KernelEnv KernelOp (PrimExp ExpLeaf)
forall a b. (a -> b) -> a -> b
$
              case CrossesSegment
crossesSegment of
                CrossesSegment
Nothing -> PrimExp ExpLeaf
forall v. PrimExp v
false
                Just PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
f -> PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
f (VName -> PrimType -> PrimExp ExpLeaf
Imp.var VName
chunk_offset PrimType
int32 PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall a. Num a => a -> a -> a
+
                             KernelConstants -> PrimExp ExpLeaf
kernelGroupSize KernelConstants
constantsPrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall a. Num a => a -> a -> a
-PrimExp ExpLeaf
1)
                            (VName -> PrimType -> PrimExp ExpLeaf
Imp.var VName
chunk_offset PrimType
int32 PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall a. Num a => a -> a -> a
+
                             KernelConstants -> PrimExp ExpLeaf
kernelGroupSize KernelConstants
constants)
            PrimExp ExpLeaf
should_load_carry <- SpaceId
-> PrimExp ExpLeaf
-> ImpM KernelsMem KernelEnv KernelOp (PrimExp ExpLeaf)
forall lore r op.
SpaceId -> PrimExp ExpLeaf -> ImpM lore r op (PrimExp ExpLeaf)
dPrimVE SpaceId
"should_load_carry" (PrimExp ExpLeaf
 -> ImpM KernelsMem KernelEnv KernelOp (PrimExp ExpLeaf))
-> PrimExp ExpLeaf
-> ImpM KernelsMem KernelEnv KernelOp (PrimExp ExpLeaf)
forall a b. (a -> b) -> a -> b
$
              KernelConstants -> PrimExp ExpLeaf
kernelLocalThreadId KernelConstants
constants PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. PrimExp ExpLeaf
0 PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. UnOp -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall v. UnOp -> PrimExp v -> PrimExp v
UnOpExp UnOp
Not PrimExp ExpLeaf
crosses_segment
            PrimExp ExpLeaf
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
PrimExp ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen PrimExp ExpLeaf
should_load_carry ImpM KernelsMem KernelEnv KernelOp ()
load_carry
            Bool
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan ImpM KernelsMem KernelEnv KernelOp ()
barrier
            PrimExp ExpLeaf
-> ImpM KernelsMem KernelEnv KernelOp ()
-> ImpM KernelsMem KernelEnv KernelOp ()
forall lore r op.
PrimExp ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sUnless PrimExp ExpLeaf
should_load_carry ImpM KernelsMem KernelEnv KernelOp ()
load_neutral

          ImpM KernelsMem KernelEnv KernelOp ()
barrier

  (VName, PrimExp ExpLeaf, CrossesSegment)
-> CallKernelGen (VName, PrimExp ExpLeaf, CrossesSegment)
forall (m :: * -> *) a. Monad m => a -> m a
return (VName
num_threads, PrimExp ExpLeaf
elems_per_group, CrossesSegment
crossesSegment)

-- | Second stage of the multi-pass scan: a single-group kernel
-- (@scan_stage2@) that scans the per-chunk carries left behind by
-- stage one.
--
-- The kernel is launched with exactly one group whose size equals the
-- number of groups used by the stage 1 kernel.  Each thread addresses
-- the flat element @(local_tid + 1) * elems_per_group - 1@, copies the
-- corresponding pattern elements into the local arrays (or the neutral
-- elements when that index is out of bounds), takes part in a
-- 'groupScan', and finally writes the scanned values back to the same
-- positions of the pattern elements.
scanStage2 :: Pattern KernelsMem
           -> VName -> Imp.Exp -> Count NumGroups SubExp -> CrossesSegment -> SegSpace
           -> [SegBinOp KernelsMem]
           -> CallKernelGen ()
scanStage2 (Pattern _ all_pes) stage1_num_threads elems_per_group num_groups crossesSegment space scans = do
  let (gtids, dims) = unzip $ unSegSpace space
  dims' <- mapM toExp dims

  -- Our group size is the number of groups for the stage 1 kernel.
  let group_size = Count $ unCount num_groups
  group_size' <- traverse toExp group_size

  -- The group scan hands us thread indices; translate them to the flat
  -- element indices those threads address (see 'flat_idx' below) before
  -- asking the original predicate whether a segment boundary is crossed.
  let crossesSegment' = do
        f <- crossesSegment
        Just $ \from to ->
          f ((from + 1) * elems_per_group - 1) ((to + 1) * elems_per_group - 1)

  sKernelThread "scan_stage2" 1 group_size' (segFlat space) $ do
    constants <- kernelConstants <$> askEnv
    per_scan_local_arrs <- makeLocalArrays group_size (Var stage1_num_threads) scans
    let per_scan_rets = map (lambdaReturnType . segBinOpLambda) scans
        per_scan_pes = segBinOpChunks scans all_pes

    -- Flat index of the element this thread is responsible for.
    flat_idx <- dPrimV "flat_idx" $
      (kernelLocalThreadId constants + 1) * elems_per_group - 1
    -- Construct segment indices.
    zipWithM_ dPrimV_ gtids $ unflattenIndex dims' $ Imp.var flat_idx int32

    forM_ (zip4 scans per_scan_local_arrs per_scan_rets per_scan_pes) $
      \(SegBinOp _ scan_op nes vec_shape, local_arrs, rets, pes) ->
        sLoopNest vec_shape $ \vec_is -> do
          let glob_is = map Imp.vi32 gtids ++ vec_is

              -- NOTE(review): 'foldl1' is partial; this assumes the
              -- segment space has at least one dimension -- confirm.
              in_bounds =
                foldl1 (.&&.) $ zipWith (.<.) (map Imp.vi32 gtids) dims'

              -- Fetch the value at this thread's global position.
              when_in_bounds = forM_ (zip3 rets local_arrs pes) $ \(t, arr, pe) ->
                copyDWIMFix arr [localArrayIndex constants t]
                (Var $ patElemName pe) glob_is

              -- Out-of-bounds threads contribute the neutral element.
              when_out_of_bounds = forM_ (zip3 rets local_arrs nes) $ \(t, arr, ne) ->
                copyDWIMFix arr [localArrayIndex constants t] ne []

              (_, _, barrier) =
                barrierFor scan_op

          sComment "threads in bound read carries; others get neutral element" $
            sIf in_bounds when_in_bounds when_out_of_bounds

          barrier

          groupScan crossesSegment'
            (Imp.vi32 stage1_num_threads) (kernelGroupSize constants) scan_op local_arrs

          sComment "threads in bounds write scanned carries" $
            sWhen in_bounds $ forM_ (zip3 rets pes local_arrs) $ \(t, pe, arr) ->
              copyDWIMFix (patElemName pe) glob_is
              (Var arr) [localArrayIndex constants t]

-- | Stage 3 of the scan: launch the @scan_stage3@ kernel, in which
-- every thread handles one element of the (flattened) iteration
-- space and, where applicable, combines the carry-in of the
-- preceding stage-1 group into the already written result.
--
-- The arguments are: the pattern whose elements hold the scan
-- results (they are both read and overwritten here), the number of
-- groups and the group size to launch the kernel with, the number of
-- elements each group processed in stage 1 (@elems_per_group@), an
-- optional predicate telling whether two flat indexes lie in
-- different segments, the iteration space, and the scan operators.
scanStage3 :: Pattern KernelsMem
           -> Count NumGroups SubExp -> Count GroupSize SubExp
           -> Imp.Exp -> CrossesSegment -> SegSpace
           -> [SegBinOp KernelsMem]
           -> CallKernelGen ()
scanStage3 (Pattern _ all_pes) num_groups group_size elems_per_group crossesSegment space scans = do
  num_groups' <- traverse toExp num_groups
  group_size' <- traverse toExp group_size
  let (gtids, dims) = unzip $ unSegSpace space
  dims' <- mapM toExp dims
  -- Number of virtual groups needed to give every element its own thread.
  required_groups <- dPrimVE "required_groups" $
                     product dims' `divUp` unCount group_size'

  sKernelThread "scan_stage3" num_groups' group_size' (segFlat space) $
    virtualiseGroups SegVirt required_groups $ \virt_group_id -> do
      constants <- kernelConstants <$> askEnv

      -- Compute our logical index.
      flat_idx <- dPrimVE "flat_idx" $
                  Imp.vi32 virt_group_id * unCount group_size' +
                  kernelLocalThreadId constants
      zipWithM_ dPrimV_ gtids $ unflattenIndex dims' flat_idx

      -- Figure out which group this element was originally in.
      orig_group <- dPrimV "orig_group" $ flat_idx `quot` elems_per_group
      -- Then the index of the carry-in of the preceding group.
      carry_in_flat_idx <- dPrimV "carry_in_flat_idx" $
                           Imp.var orig_group int32 * elems_per_group - 1
      -- Figure out the logical index of the carry-in.
      let carry_in_idx = unflattenIndex dims' $ Imp.var carry_in_flat_idx int32

      -- Apply the carry if we are not in the scan results for the first
      -- group, and are not the last element in such a group (because
      -- then the carry was updated in stage 2), and we are not crossing
      -- a segment boundary.
      --
      -- Note that 'foldl1' is partial: the space must have at least
      -- one dimension.
      let in_bounds =
            foldl1 (.&&.) $ zipWith (.<.) (map Imp.vi32 gtids) dims'
          -- Without a segment predicate, nothing ever crosses a segment.
          crosses_segment = fromMaybe false $
            crossesSegment <*>
              pure (Imp.var carry_in_flat_idx int32) <*>
              pure flat_idx
          is_a_carry = flat_idx .==.
                       (Imp.var orig_group int32 + 1) * elems_per_group - 1
          no_carry_in = Imp.var orig_group int32 .==. 0 .||. is_a_carry .||. crosses_segment

      let per_scan_pes = segBinOpChunks scans all_pes
      sWhen in_bounds $ sUnless no_carry_in $
        forM_ (zip per_scan_pes scans) $
        \(pes, SegBinOp _ scan_op nes vec_shape) -> do
          dScope Nothing $ scopeOfLParams $ lambdaParams scan_op
          -- The first @length nes@ parameters are the left-hand
          -- operands, the remainder the right-hand operands.
          let (scan_x_params, scan_y_params) =
                splitAt (length nes) $ lambdaParams scan_op

          sLoopNest vec_shape $ \vec_is -> do
            -- Left operand: the carry-in of the preceding group.
            forM_ (zip scan_x_params pes) $ \(p, pe) ->
              copyDWIMFix (paramName p) []
              (Var $ patElemName pe) (carry_in_idx ++ vec_is)

            -- Right operand: the value currently stored for this element.
            forM_ (zip scan_y_params pes) $ \(p, pe) ->
              copyDWIMFix (paramName p) []
              (Var $ patElemName pe) (map Imp.vi32 gtids ++ vec_is)

            compileBody' scan_x_params $ lambdaBody scan_op

            -- Write the combined value back to this element's position.
            forM_ (zip scan_x_params pes) $ \(p, pe) ->
              copyDWIMFix (patElemName pe) (map Imp.vi32 gtids ++ vec_is)
              (Var $ paramName p) []

-- | Compile 'SegScan' instance to host-level code with calls to
-- various kernels.
--
-- The generated code is guarded so that the three scan stages only
-- run when the product of the dimensions of the iteration space is
-- positive.
compileSegScan :: Pattern KernelsMem
               -> SegLevel -> SegSpace
               -> [SegBinOp KernelsMem]
               -> KernelBody KernelsMem
               -> CallKernelGen ()
compileSegScan pat lvl space scans kbody = sWhen (0 .<. n) $ do
  emit $ Imp.DebugPrint "\n# SegScan" Nothing

  -- Since stage 2 involves a group size equal to the number of groups
  -- used for stage 1, we have to cap this number to the maximum group
  -- size.
  stage1_max_num_groups <-
    dPrim "stage1_max_num_groups" int32
  sOp $ Imp.GetSizeMax stage1_max_num_groups SizeGroup

  -- The smaller of the maximum group size and the requested number of
  -- groups.
  stage1_num_groups <-
    fmap (Imp.Count . Var) $ dPrimV "stage1_num_groups" $
    Imp.BinOpExp (SMin Int32) (Imp.vi32 stage1_max_num_groups) $
    toExp' int32 $ Imp.unCount $ segNumGroups lvl

  (stage1_num_threads, elems_per_group, crossesSegment) <-
    scanStage1 pat stage1_num_groups (segGroupSize lvl) space scans kbody

  emit $ Imp.DebugPrint "elems_per_group" $ Just elems_per_group

  scanStage2 pat stage1_num_threads elems_per_group stage1_num_groups crossesSegment space scans
  -- Stage 3 is launched with the number of groups requested by the
  -- level, not the capped stage-1 count.
  scanStage3 pat (segNumGroups lvl) (segGroupSize lvl) elems_per_group crossesSegment space scans
  where -- Total number of elements in the iteration space.
        n = product $ map (toExp' int32) $ segSpaceDims space