{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Code generation for segmented and non-segmented scans.  Uses a
-- fairly inefficient two-pass algorithm, but can handle anything.
module Futhark.CodeGen.ImpGen.GPU.SegScan.TwoPass (compileSegScan) where

import Control.Monad.Except
import Control.Monad.State
import Data.List (delete, find, foldl', zip4)
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.GPU as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Transform.Rename
import Futhark.Util (takeLast)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)

-- Create the intermediate arrays used by each scan operator.
-- Aggressively try to reuse memory for different SegBinOps, because
-- we will run them sequentially, one after another.
--
-- The result contains one inner list per operator in @scans@, with
-- one array per "x" parameter of that operator's lambda (the first
-- @length nes@ lambda parameters).
makeLocalArrays ::
  Count GroupSize SubExp ->
  SubExp ->
  [SegBinOp GPUMem] ->
  InKernelGen [[VName]]
makeLocalArrays (Count group_size) num_threads scans = do
  -- The state threaded through 'onScan' is the pool of memory blocks
  -- that are currently free for reuse, each paired with every size
  -- that has been requested of it so far.
  (arrs, mems_and_sizes) <- runStateT (mapM onScan scans) mempty
  -- A shared block must fit the largest array placed in it.  The fold
  -- starts from 1, so the computed size is never below that.
  let maxSize sizes = Imp.bytes $ foldl' sMax64 1 $ map Imp.unCount sizes
  forM_ mems_and_sizes $ \(sizes, mem) ->
    sAlloc_ mem (maxSize sizes) (Space "local")
  return arrs
  where
    -- Produce the arrays for a single operator, one per x parameter.
    onScan (SegBinOp _ scan_op nes _) = do
      let (scan_x_params, _scan_y_params) =
            splitAt (length nes) $ lambdaParams scan_op
      (arrs, used_mems) <- fmap unzip $
        forM scan_x_params $ \p ->
          case paramDec p of
            -- Array-typed parameter: reuse the parameter's own memory
            -- block, with an extra outer dimension of size
            -- @num_threads@.  Nothing is taken from the pool.
            MemArray pt shape _ (ArrayIn mem _) -> do
              let shape' = Shape [num_threads] <> shape
              arr <-
                lift . sArray "scan_arr" pt shape' mem $
                  IxFun.iota $ map pe64 $ shapeDims shape'
              return (arr, [])
            -- Any other parameter: a @group_size@-element array in a
            -- (possibly reused) block obtained from 'getMem'.
            _ -> do
              let pt = elemType $ paramType p
                  shape = Shape [group_size]
              (sizes, mem') <- getMem pt shape
              arr <- lift $ sArrayInMem "scan_arr" pt shape mem'
              return (arr, [(sizes, mem')])
      -- Only hand the blocks back to the pool once the whole operator
      -- has been processed, so that two arrays belonging to the same
      -- operator never end up sharing a block.
      modify (<> concat used_mems)
      return arrs

    -- Obtain a memory block for an array of the given element type and
    -- shape, removing it from the pool of free blocks.
    getMem pt shape = do
      let size = typeSize $ Array pt shape NoUniqueness
      mems <- get
      case (find ((size `elem`) . fst) mems, mems) of
        -- A free block that has already been used at exactly this size.
        (Just mem, _) -> do
          modify $ delete mem
          return mem
        -- Otherwise take the first free block and record the new size.
        (Nothing, (size', mem) : mems') -> do
          put mems'
          return (size : size', mem)
        -- Nothing is free: declare a fresh block.
        (Nothing, []) -> do
          mem <- lift $ sDeclareMem "scan_arr_mem" $ Space "local"
          return ([size], mem)

-- | An optional predicate over two 64-bit index expressions, yielding
-- a boolean expression.
--
-- NOTE(review): judging by the name, this presumably reports whether
-- a segment boundary lies between the two given flat indices, with
-- 'Nothing' standing for a non-segmented scan - confirm against the
-- use sites, which are outside this part of the file.
type CrossesSegment = Maybe (Imp.TExp Int64 -> Imp.TExp Int64 -> Imp.TExp Bool)

-- | The index a thread uses for an intermediate scan array holding
-- values of type @t@: the thread's local ID (within its group) for
-- primitive types, and its global ID otherwise.  Either ID is
-- sign-extended to 64 bits.
localArrayIndex :: KernelConstants -> Type -> Imp.TExp Int64
localArrayIndex constants t =
  if primType t
    then sExt64 (kernelLocalThreadId constants)
    else sExt64 (kernelGlobalThreadId constants)

-- | Select the barrier to use for a scan operator.  Returns, in
-- order: whether any result of the operator is non-primitive (an
-- "array scan"), the fence kind chosen for it, and an action that
-- emits a barrier with that fence.
barrierFor :: Lambda GPUMem -> (Bool, Imp.Fence, InKernelGen ())
barrierFor scan_op = (array_scan, fence, sOp $ Imp.Barrier fence)
  where
    -- True as soon as one return type is not a primitive type.
    array_scan = not $ all primType $ lambdaReturnType scan_op
    -- Array scans get a global fence; all others a local fence.
    fence
      | array_scan = Imp.FenceGlobal
      | otherwise = Imp.FenceLocal

-- | The parameters of a scan operator's lambda, split at the number
-- of neutral elements: 'xParams' are the first @length
-- (segBinOpNeutral scan)@ parameters and 'yParams' are the remaining
-- ones.
xParams, yParams :: SegBinOp GPUMem -> [LParam GPUMem]
xParams scan =
  take (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
yParams scan =
  drop (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))

-- | Store the kernel results belonging to one scan operator.
--
-- If the operator has a non-empty shape (rank greater than zero),
-- each result is copied into the corresponding pattern element at the
-- index given by @gtids@.  Otherwise each result is copied directly
-- into the corresponding "y" parameter of the operator.
writeToScanValues ::
  [VName] ->
  ([PatElem LetDecMem], SegBinOp GPUMem, [KernelResult]) ->
  InKernelGen ()
writeToScanValues gtids (pes, scan, scan_res)
  | shapeRank (segBinOpShape scan) > 0 =
    forM_ (zip pes scan_res) $ \(pe, res) ->
      copyDWIMFix
        (patElemName pe)
        (map Imp.le64 gtids)
        (kernelResultSubExp res)
        []
  | otherwise =
    forM_ (zip (yParams scan) scan_res) $ \(p, res) ->
      copyDWIMFix (paramName p) [] (kernelResultSubExp res) []

-- | Load the "y" parameters of a scan operator from the pattern
-- elements.
--
-- Only does anything when the operator has a non-empty shape (rank
-- greater than zero): each y parameter is then read from the
-- corresponding pattern element at index @is@.  For operators of rank
-- zero this is a no-op.
readToScanValues ::
  [Imp.TExp Int64] ->
  [PatElem LetDecMem] ->
  SegBinOp GPUMem ->
  InKernelGen ()
readToScanValues is pes scan
  | shapeRank (segBinOpShape scan) > 0 =
    forM_ (zip (yParams scan) pes) $ \(p, pe) ->
      copyDWIMFix (paramName p) [] (Var (patElemName pe)) is
  | otherwise =
    return ()

readCarries ::
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  [Imp.TExp Int64] ->
  [Imp.TExp Int64] ->
  [PatElem LetDecMem] ->
  SegBinOp GPUMem ->
  InKernelGen ()
readCarries :: TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> Shape (TPrimExp Int64 VName)
-> Shape (TPrimExp Int64 VName)
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> SegBinOp GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
readCarries TPrimExp Int64 VName
chunk_id TPrimExp Int64 VName
chunk_offset Shape (TPrimExp Int64 VName)
dims' Shape (TPrimExp Int64 VName)
vec_is [PatElem (MemInfo SubExp NoUniqueness MemBind)]
pes SegBinOp GPUMem
scan
  | Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank (SegBinOp GPUMem -> Shape
forall rep. SegBinOp rep -> Shape
segBinOpShape SegBinOp GPUMem
scan) Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
0 = do
    TPrimExp Int32 VName
ltid <- KernelConstants -> TPrimExp Int32 VName
kernelLocalThreadId (KernelConstants -> TPrimExp Int32 VName)
-> (KernelEnv -> KernelConstants)
-> KernelEnv
-> TPrimExp Int32 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> TPrimExp Int32 VName)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int32 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
    -- We may have to reload the carries from the output of the
    -- previous chunk.
    TExp Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
      (TPrimExp Int64 VName
chunk_id TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TPrimExp Int64 VName
0 TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp Int32 VName
ltid TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 VName
0)
      ( do
          let is :: Shape (TPrimExp Int64 VName)
is = Shape (TPrimExp Int64 VName)
-> TPrimExp Int64 VName -> Shape (TPrimExp Int64 VName)
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex Shape (TPrimExp Int64 VName)
dims' (TPrimExp Int64 VName -> Shape (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> Shape (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
chunk_offset TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1
          [(Param (MemInfo SubExp NoUniqueness MemBind),
  PatElem (MemInfo SubExp NoUniqueness MemBind))]
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
     PatElem (MemInfo SubExp NoUniqueness MemBind))
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> [(Param (MemInfo SubExp NoUniqueness MemBind),
     PatElem (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scan) [PatElem (MemInfo SubExp NoUniqueness MemBind)]
pes) (((Param (MemInfo SubExp NoUniqueness MemBind),
   PatElem (MemInfo SubExp NoUniqueness MemBind))
  -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
     PatElem (MemInfo SubExp NoUniqueness MemBind))
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, PatElem (MemInfo SubExp NoUniqueness MemBind)
pe) ->
            VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM rep r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var (PatElem (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (MemInfo SubExp NoUniqueness MemBind)
pe)) (Shape (TPrimExp Int64 VName)
is Shape (TPrimExp Int64 VName)
-> Shape (TPrimExp Int64 VName) -> Shape (TPrimExp Int64 VName)
forall a. [a] -> [a] -> [a]
++ Shape (TPrimExp Int64 VName)
vec_is)
      )
      ( [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [SubExp]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scan) (SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan)) (((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
  -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, SubExp
ne) ->
          VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM rep r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] SubExp
ne []
      )
  | Bool
otherwise =
    () -> ImpM GPUMem KernelEnv KernelOp ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()

-- | Produce partially scanned intervals; one per workgroup.
scanStage1 ::
  Pat LetDecMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen (TV Int32, Imp.TExp Int64, CrossesSegment)
scanStage1 :: Pat (MemInfo SubExp NoUniqueness MemBind)
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp GPUMem]
-> KernelBody GPUMem
-> CallKernelGen (TV Int32, TPrimExp Int64 VName, CrossesSegment)
scanStage1 (Pat [PatElem (MemInfo SubExp NoUniqueness MemBind)]
all_pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegBinOp GPUMem]
scans KernelBody GPUMem
kbody = do
  let num_groups' :: Count NumGroups (TPrimExp Int64 VName)
num_groups' = (SubExp -> TPrimExp Int64 VName)
-> Count NumGroups SubExp -> Count NumGroups (TPrimExp Int64 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
forall a. ToExp a => a -> TPrimExp Int64 VName
toInt64Exp Count NumGroups SubExp
num_groups
      group_size' :: Count GroupSize (TPrimExp Int64 VName)
group_size' = (SubExp -> TPrimExp Int64 VName)
-> Count GroupSize SubExp -> Count GroupSize (TPrimExp Int64 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
forall a. ToExp a => a -> TPrimExp Int64 VName
toInt64Exp Count GroupSize SubExp
group_size
  TV Int32
num_threads <- SpaceId
-> TPrimExp Int32 VName -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall t rep r op. SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"num_threads" (TPrimExp Int32 VName -> ImpM GPUMem HostEnv HostOp (TV Int32))
-> TPrimExp Int32 VName -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> TPrimExp Int32 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TPrimExp Int32 VName)
-> TPrimExp Int64 VName -> TPrimExp Int32 VName
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
num_groups' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* Count GroupSize (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 VName)
group_size'

  let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      dims' :: Shape (TPrimExp Int64 VName)
dims' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> Shape (TPrimExp Int64 VName)
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
forall a. ToExp a => a -> TPrimExp Int64 VName
toInt64Exp [SubExp]
dims
  let num_elements :: TPrimExp Int64 VName
num_elements = Shape (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product Shape (TPrimExp Int64 VName)
dims'
      elems_per_thread :: TPrimExp Int64 VName
elems_per_thread = TPrimExp Int64 VName
num_elements TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int32 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TPrimExp Int32 VName
forall t. TV t -> TExp t
tvExp TV Int32
num_threads)
      elems_per_group :: TPrimExp Int64 VName
elems_per_group = Count GroupSize (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall u e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 VName)
group_size' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
elems_per_thread

  let crossesSegment :: CrossesSegment
crossesSegment =
        case Shape (TPrimExp Int64 VName) -> Shape (TPrimExp Int64 VName)
forall a. [a] -> [a]
reverse Shape (TPrimExp Int64 VName)
dims' of
          TPrimExp Int64 VName
segment_size : TPrimExp Int64 VName
_ : Shape (TPrimExp Int64 VName)
_ -> (TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool)
-> CrossesSegment
forall a. a -> Maybe a
Just ((TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool)
 -> CrossesSegment)
-> (TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool)
-> CrossesSegment
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
from TPrimExp Int64 VName
to ->
            (TPrimExp Int64 VName
to TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
from) TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TPrimExp Int64 VName
to TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName
segment_size)
          Shape (TPrimExp Int64 VName)
_ -> CrossesSegment
forall a. Maybe a
Nothing

  SpaceId
-> VName
-> KernelAttrs
-> ImpM GPUMem KernelEnv KernelOp ()
-> CallKernelGen ()
sKernelThread SpaceId
"scan_stage1" (SegSpace -> VName
segFlat SegSpace
space) (Count NumGroups (TPrimExp Int64 VName)
-> Count GroupSize (TPrimExp Int64 VName) -> KernelAttrs
defKernelAttrs Count NumGroups (TPrimExp Int64 VName)
num_groups' Count GroupSize (TPrimExp Int64 VName)
group_size') (ImpM GPUMem KernelEnv KernelOp () -> CallKernelGen ())
-> ImpM GPUMem KernelEnv KernelOp () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
    [[VName]]
all_local_arrs <- Count GroupSize SubExp
-> SubExp -> [SegBinOp GPUMem] -> InKernelGen [[VName]]
makeLocalArrays Count GroupSize SubExp
group_size (TV Int32 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int32
num_threads) [SegBinOp GPUMem]
scans

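    -- Declare the parameters of every scan operator up front and set
    -- the x-parameters to the neutral element.  These variables carry
    -- the running result between iterations of the chunking loop below.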
    [SegBinOp GPUMem]
-> (SegBinOp GPUMem -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [SegBinOp GPUMem]
scans ((SegBinOp GPUMem -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (SegBinOp GPUMem -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \SegBinOp GPUMem
scan -> do
      Maybe (Exp GPUMem)
-> Scope GPUMem -> ImpM GPUMem KernelEnv KernelOp ()
forall rep inner r op.
Mem rep inner =>
Maybe (Exp rep) -> Scope rep -> ImpM rep r op ()
dScope Maybe (Exp GPUMem)
forall a. Maybe a
Nothing (Scope GPUMem -> ImpM GPUMem KernelEnv KernelOp ())
-> Scope GPUMem -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ [Param (MemInfo SubExp NoUniqueness MemBind)] -> Scope GPUMem
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams ([Param (MemInfo SubExp NoUniqueness MemBind)] -> Scope GPUMem)
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> Scope GPUMem
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams (Lambda GPUMem -> [LParam GPUMem])
-> Lambda GPUMem -> [LParam GPUMem]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan
      [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [SubExp]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scan) (SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan)) (((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
  -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, SubExp
ne) ->
        VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM rep r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] SubExp
ne []

    SpaceId
-> TPrimExp Int64 VName
-> (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
SpaceId
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor SpaceId
"j" TPrimExp Int64 VName
elems_per_thread ((TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
j -> do
      TV Int64
chunk_offset <-
        SpaceId
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall t rep r op. SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"chunk_offset" (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
          TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
j
            TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int32 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 VName
kernelGroupId KernelConstants
constants) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
elems_per_group
      TV Int64
flat_idx <-
        SpaceId
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall t rep r op. SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"flat_idx" (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
          TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int32 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 VName
kernelLocalThreadId KernelConstants
constants)
      -- Construct segment indices.
      (VName
 -> TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> [VName]
-> Shape (TPrimExp Int64 VName)
-> ImpM GPUMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
gtids (Shape (TPrimExp Int64 VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> Shape (TPrimExp Int64 VName)
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ Shape (TPrimExp Int64 VName)
-> TPrimExp Int64 VName -> Shape (TPrimExp Int64 VName)
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex Shape (TPrimExp Int64 VName)
dims' (TPrimExp Int64 VName -> Shape (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> Shape (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
flat_idx

      let per_scan_pes :: [[PatElem (MemInfo SubExp NoUniqueness MemBind)]]
per_scan_pes = [SegBinOp GPUMem]
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> [[PatElem (MemInfo SubExp NoUniqueness MemBind)]]
forall rep a. [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks [SegBinOp GPUMem]
scans [PatElem (MemInfo SubExp NoUniqueness MemBind)]
all_pes

          in_bounds :: TExp Bool
in_bounds =
            (TExp Bool -> TExp Bool -> TExp Bool) -> [TExp Bool] -> TExp Bool
forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldl1 TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
(.&&.) ([TExp Bool] -> TExp Bool) -> [TExp Bool] -> TExp Bool
forall a b. (a -> b) -> a -> b
$ (TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool)
-> Shape (TPrimExp Int64 VName)
-> Shape (TPrimExp Int64 VName)
-> [TExp Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
(.<.) ((VName -> TPrimExp Int64 VName)
-> [VName] -> Shape (TPrimExp Int64 VName)
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids) Shape (TPrimExp Int64 VName)
dims'

          when_in_bounds :: ImpM GPUMem KernelEnv KernelOp ()
when_in_bounds = Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
            let ([KernelResult]
all_scan_res, [KernelResult]
map_res) =
                  Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp GPUMem] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem]
scans) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody
                per_scan_res :: [[KernelResult]]
per_scan_res =
                  [SegBinOp GPUMem] -> [KernelResult] -> [[KernelResult]]
forall rep a. [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks [SegBinOp GPUMem]
scans [KernelResult]
all_scan_res

            SpaceId
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. SpaceId -> ImpM rep r op () -> ImpM rep r op ()
sComment SpaceId
"write to-scan values to parameters" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
              (([PatElem (MemInfo SubExp NoUniqueness MemBind)], SegBinOp GPUMem,
  [KernelResult])
 -> ImpM GPUMem KernelEnv KernelOp ())
-> [([PatElem (MemInfo SubExp NoUniqueness MemBind)],
     SegBinOp GPUMem, [KernelResult])]
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ([VName]
-> ([PatElem (MemInfo SubExp NoUniqueness MemBind)],
    SegBinOp GPUMem, [KernelResult])
-> ImpM GPUMem KernelEnv KernelOp ()
writeToScanValues [VName]
gtids) ([([PatElem (MemInfo SubExp NoUniqueness MemBind)],
   SegBinOp GPUMem, [KernelResult])]
 -> ImpM GPUMem KernelEnv KernelOp ())
-> [([PatElem (MemInfo SubExp NoUniqueness MemBind)],
     SegBinOp GPUMem, [KernelResult])]
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                [[PatElem (MemInfo SubExp NoUniqueness MemBind)]]
-> [SegBinOp GPUMem]
-> [[KernelResult]]
-> [([PatElem (MemInfo SubExp NoUniqueness MemBind)],
     SegBinOp GPUMem, [KernelResult])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElem (MemInfo SubExp NoUniqueness MemBind)]]
per_scan_pes [SegBinOp GPUMem]
scans [[KernelResult]]
per_scan_res

            SpaceId
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. SpaceId -> ImpM rep r op () -> ImpM rep r op ()
sComment SpaceId
"write mapped values results to global memory" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
              [(PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
-> ((PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> [KernelResult]
-> [(PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
takeLast ([KernelResult] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
map_res) [PatElem (MemInfo SubExp NoUniqueness MemBind)]
all_pes) [KernelResult]
map_res) (((PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)
  -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(PatElem (MemInfo SubExp NoUniqueness MemBind)
pe, KernelResult
se) ->
                VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM rep r op ()
copyDWIMFix
                  (PatElem (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (MemInfo SubExp NoUniqueness MemBind)
pe)
                  ((VName -> TPrimExp Int64 VName)
-> [VName] -> Shape (TPrimExp Int64 VName)
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids)
                  (KernelResult -> SubExp
kernelResultSubExp KernelResult
se)
                  []

      SpaceId
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. SpaceId -> ImpM rep r op () -> ImpM rep r op ()
sComment SpaceId
"threads in bounds read input" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
        TExp Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
in_bounds ImpM GPUMem KernelEnv KernelOp ()
when_in_bounds

      Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ((SegBinOp GPUMem -> Bool) -> [SegBinOp GPUMem] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Shape -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (Shape -> Bool)
-> (SegBinOp GPUMem -> Shape) -> SegBinOp GPUMem -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp GPUMem -> Shape
forall rep. SegBinOp rep -> Shape
segBinOpShape) [SegBinOp GPUMem]
scans) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
        KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> ImpM GPUMem KernelEnv KernelOp ())
-> KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceGlobal

      [([PatElem (MemInfo SubExp NoUniqueness MemBind)], SegBinOp GPUMem,
  [VName])]
-> (([PatElem (MemInfo SubExp NoUniqueness MemBind)],
     SegBinOp GPUMem, [VName])
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([[PatElem (MemInfo SubExp NoUniqueness MemBind)]]
-> [SegBinOp GPUMem]
-> [[VName]]
-> [([PatElem (MemInfo SubExp NoUniqueness MemBind)],
     SegBinOp GPUMem, [VName])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElem (MemInfo SubExp NoUniqueness MemBind)]]
per_scan_pes [SegBinOp GPUMem]
scans [[VName]]
all_local_arrs) ((([PatElem (MemInfo SubExp NoUniqueness MemBind)],
   SegBinOp GPUMem, [VName])
  -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (([PatElem (MemInfo SubExp NoUniqueness MemBind)],
     SegBinOp GPUMem, [VName])
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
        \([PatElem (MemInfo SubExp NoUniqueness MemBind)]
pes, scan :: SegBinOp GPUMem
scan@(SegBinOp Commutativity
_ Lambda GPUMem
scan_op [SubExp]
nes Shape
vec_shape), [VName]
local_arrs) ->
          SpaceId
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. SpaceId -> ImpM rep r op () -> ImpM rep r op ()
sComment SpaceId
"do one intra-group scan operation" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
            let rets :: [TypeBase Shape NoUniqueness]
rets = Lambda GPUMem -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda GPUMem
scan_op
                scan_x_params :: [LParam GPUMem]
scan_x_params = SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scan
                (Bool
array_scan, Fence
fence, ImpM GPUMem KernelEnv KernelOp ()
barrier) = Lambda GPUMem -> (Bool, Fence, ImpM GPUMem KernelEnv KernelOp ())
barrierFor Lambda GPUMem
scan_op

            Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan ImpM GPUMem KernelEnv KernelOp ()
barrier

            Shape
-> (Shape (TPrimExp Int64 VName)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Shape
-> (Shape (TPrimExp Int64 VName) -> ImpM rep r op ())
-> ImpM rep r op ()
sLoopNest Shape
vec_shape ((Shape (TPrimExp Int64 VName)
  -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (Shape (TPrimExp Int64 VName)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \Shape (TPrimExp Int64 VName)
vec_is -> do
              SpaceId
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. SpaceId -> ImpM rep r op () -> ImpM rep r op ()
sComment SpaceId
"maybe restore some to-scan values to parameters, or read neutral" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                TExp Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
                  TExp Bool
in_bounds
                  ( do
                      Shape (TPrimExp Int64 VName)
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> SegBinOp GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
readToScanValues ((VName -> TPrimExp Int64 VName)
-> [VName] -> Shape (TPrimExp Int64 VName)
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids Shape (TPrimExp Int64 VName)
-> Shape (TPrimExp Int64 VName) -> Shape (TPrimExp Int64 VName)
forall a. [a] -> [a] -> [a]
++ Shape (TPrimExp Int64 VName)
vec_is) [PatElem (MemInfo SubExp NoUniqueness MemBind)]
pes SegBinOp GPUMem
scan
                      TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> Shape (TPrimExp Int64 VName)
-> Shape (TPrimExp Int64 VName)
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> SegBinOp GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
readCarries TPrimExp Int64 VName
j (TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset) Shape (TPrimExp Int64 VName)
dims' Shape (TPrimExp Int64 VName)
vec_is [PatElem (MemInfo SubExp NoUniqueness MemBind)]
pes SegBinOp GPUMem
scan
                  )
                  ( [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [SubExp]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp GPUMem -> [LParam GPUMem]
yParams SegBinOp GPUMem
scan) (SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan)) (((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
  -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, SubExp
ne) ->
                      VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM rep r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] SubExp
ne []
                  )

              SpaceId
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. SpaceId -> ImpM rep r op () -> ImpM rep r op ()
sComment SpaceId
"combine with carry and write to local memory" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body GPUMem -> Stms GPUMem
forall rep. Body rep -> Stms rep
bodyStms (Body GPUMem -> Stms GPUMem) -> Body GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                  [(TypeBase Shape NoUniqueness, VName, SubExp)]
-> ((TypeBase Shape NoUniqueness, VName, SubExp)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TypeBase Shape NoUniqueness]
-> [VName]
-> [SubExp]
-> [(TypeBase Shape NoUniqueness, VName, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TypeBase Shape NoUniqueness]
rets [VName]
local_arrs ([SubExp] -> [(TypeBase Shape NoUniqueness, VName, SubExp)])
-> [SubExp] -> [(TypeBase Shape NoUniqueness, VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Body GPUMem -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult (Body GPUMem -> [SubExpRes]) -> Body GPUMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op) (((TypeBase Shape NoUniqueness, VName, SubExp)
  -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((TypeBase Shape NoUniqueness, VName, SubExp)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                    \(TypeBase Shape NoUniqueness
t, VName
arr, SubExp
se) ->
                      VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM rep r op ()
copyDWIMFix VName
arr [KernelConstants
-> TypeBase Shape NoUniqueness -> TPrimExp Int64 VName
localArrayIndex KernelConstants
constants TypeBase Shape NoUniqueness
t] SubExp
se []

              let crossesSegment' :: Maybe (TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool)
crossesSegment' = do
                    TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
f <- CrossesSegment
crossesSegment
                    (TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool)
-> Maybe
     (TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool)
forall a. a -> Maybe a
Just ((TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool)
 -> Maybe
      (TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool))
-> (TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool)
-> Maybe
     (TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool)
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int32 VName
from TPrimExp Int32 VName
to ->
                      let from' :: TPrimExp Int64 VName
from' = TPrimExp Int32 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
from TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset
                          to' :: TPrimExp Int64 VName
to' = TPrimExp Int32 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
to TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset
                       in TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
f TPrimExp Int64 VName
from' TPrimExp Int64 VName
to'

              KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> ImpM GPUMem KernelEnv KernelOp ())
-> KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
fence

              -- We need to avoid parameter name clashes.
              Lambda GPUMem
scan_op_renamed <- Lambda GPUMem -> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scan_op
              Maybe (TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool)
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> Lambda GPUMem
-> [VName]
-> ImpM GPUMem KernelEnv KernelOp ()
groupScan
                Maybe (TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool)
crossesSegment'
                (TPrimExp Int32 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int32 VName -> TPrimExp Int64 VName)
-> TPrimExp Int32 VName -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TPrimExp Int32 VName
forall t. TV t -> TExp t
tvExp TV Int32
num_threads)
                (TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int64 VName -> TPrimExp Int64 VName)
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants)
                Lambda GPUMem
scan_op_renamed
                [VName]
local_arrs

              SpaceId
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. SpaceId -> ImpM rep r op () -> ImpM rep r op ()
sComment SpaceId
"threads in bounds write partial scan result" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                TExp Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
in_bounds (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                  [(TypeBase Shape NoUniqueness,
  PatElem (MemInfo SubExp NoUniqueness MemBind), VName)]
-> ((TypeBase Shape NoUniqueness,
     PatElem (MemInfo SubExp NoUniqueness MemBind), VName)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TypeBase Shape NoUniqueness]
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> [(TypeBase Shape NoUniqueness,
     PatElem (MemInfo SubExp NoUniqueness MemBind), VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TypeBase Shape NoUniqueness]
rets [PatElem (MemInfo SubExp NoUniqueness MemBind)]
pes [VName]
local_arrs) (((TypeBase Shape NoUniqueness,
   PatElem (MemInfo SubExp NoUniqueness MemBind), VName)
  -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((TypeBase Shape NoUniqueness,
     PatElem (MemInfo SubExp NoUniqueness MemBind), VName)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TypeBase Shape NoUniqueness
t, PatElem (MemInfo SubExp NoUniqueness MemBind)
pe, VName
arr) ->
                    VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM rep r op ()
copyDWIMFix
                      (PatElem (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (MemInfo SubExp NoUniqueness MemBind)
pe)
                      ((VName -> TPrimExp Int64 VName)
-> [VName] -> Shape (TPrimExp Int64 VName)
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids Shape (TPrimExp Int64 VName)
-> Shape (TPrimExp Int64 VName) -> Shape (TPrimExp Int64 VName)
forall a. [a] -> [a] -> [a]
++ Shape (TPrimExp Int64 VName)
vec_is)
                      (VName -> SubExp
Var VName
arr)
                      [KernelConstants
-> TypeBase Shape NoUniqueness -> TPrimExp Int64 VName
localArrayIndex KernelConstants
constants TypeBase Shape NoUniqueness
t]

              ImpM GPUMem KernelEnv KernelOp ()
barrier

              let load_carry :: ImpM GPUMem KernelEnv KernelOp ()
load_carry =
                    [(VName, Param (MemInfo SubExp NoUniqueness MemBind))]
-> ((VName, Param (MemInfo SubExp NoUniqueness MemBind))
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [(VName, Param (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
local_arrs [LParam GPUMem]
[Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params) (((VName, Param (MemInfo SubExp NoUniqueness MemBind))
  -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, Param (MemInfo SubExp NoUniqueness MemBind))
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, Param (MemInfo SubExp NoUniqueness MemBind)
p) ->
                      VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM rep r op ()
copyDWIMFix
                        (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p)
                        []
                        (VName -> SubExp
Var VName
arr)
                        [ if TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind)
-> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param (MemInfo SubExp NoUniqueness MemBind)
p
                            then TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1
                            else
                              (TPrimExp Int32 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 VName
kernelGroupId KernelConstants
constants) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1)
                                TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1
                        ]
                  load_neutral :: ImpM GPUMem KernelEnv KernelOp ()
load_neutral =
                    [(SubExp, Param (MemInfo SubExp NoUniqueness MemBind))]
-> ((SubExp, Param (MemInfo SubExp NoUniqueness MemBind))
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SubExp]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [(SubExp, Param (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
nes [LParam GPUMem]
[Param (MemInfo SubExp NoUniqueness MemBind)]
scan_x_params) (((SubExp, Param (MemInfo SubExp NoUniqueness MemBind))
  -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((SubExp, Param (MemInfo SubExp NoUniqueness MemBind))
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, Param (MemInfo SubExp NoUniqueness MemBind)
p) ->
                      VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> Shape (TPrimExp Int64 VName)
-> SubExp
-> Shape (TPrimExp Int64 VName)
-> ImpM rep r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] SubExp
ne []

              SpaceId
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. SpaceId -> ImpM rep r op () -> ImpM rep r op ()
sComment SpaceId
"first thread reads last element as carry-in for next iteration" (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
                TExp Bool
crosses_segment <- SpaceId -> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall t rep r op. SpaceId -> TExp t -> ImpM rep r op (TExp t)
dPrimVE SpaceId
"crosses_segment" (TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool))
-> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall a b. (a -> b) -> a -> b
$
                  case CrossesSegment
crossesSegment of
                    CrossesSegment
Nothing -> TExp Bool
forall v. TPrimExp Bool v
false
                    Just TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
f ->
                      TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
f
                        ( TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset
                            TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
-TPrimExp Int64 VName
1
                        )
                        ( TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset
                            TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants)
                        )
                TExp Bool
should_load_carry <-
                  SpaceId -> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall t rep r op. SpaceId -> TExp t -> ImpM rep r op (TExp t)
dPrimVE SpaceId
"should_load_carry" (TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool))
-> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall a b. (a -> b) -> a -> b
$
                    KernelConstants -> TPrimExp Int32 VName
kernelLocalThreadId KernelConstants
constants TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 VName
0 TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TExp Bool
crosses_segment
                TExp Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
should_load_carry ImpM GPUMem KernelEnv KernelOp ()
load_carry
                Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan ImpM GPUMem KernelEnv KernelOp ()
barrier
                TExp Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
should_load_carry ImpM GPUMem KernelEnv KernelOp ()
load_neutral

              ImpM GPUMem KernelEnv KernelOp ()
barrier

  (TV Int32, TPrimExp Int64 VName, CrossesSegment)
-> CallKernelGen (TV Int32, TPrimExp Int64 VName, CrossesSegment)
forall (m :: * -> *) a. Monad m => a -> m a
return (TV Int32
num_threads, TPrimExp Int64 VName
elems_per_group, CrossesSegment
crossesSegment)

-- | Second stage of the two-pass scan: scan the per-group carries
-- left behind by stage one.
--
-- Launches a single-group kernel named @scan_stage2@ whose group size
-- is the number of groups used by the stage-one kernel.  The thread
-- with local id @i@ works on the flat index
-- @(i + 1) * elems_per_group - 1@ (presumably the last element of the
-- chunk handled by stage-one group @i@ -- confirm against
-- 'scanStage1').  For every scan operator, and for every index of its
-- vectorised shape, the kernel:
--
--   1. copies the pattern element at that index into the local array
--      if the index is in bounds, and the neutral element otherwise;
--
--   2. synchronises with the barrier chosen by 'barrierFor';
--
--   3. runs 'groupScan' over the local arrays;
--
--   4. has the in-bounds threads copy the scanned value back to the
--      same position in the pattern element.
scanStage2 ::
  Pat LetDecMem ->
  TV Int32 ->
  Imp.TExp Int64 ->
  Count NumGroups SubExp ->
  CrossesSegment ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  CallKernelGen ()
scanStage2 (Pat all_pes) stage1_num_threads elems_per_group num_groups crossesSegment space scans = do
  let (gtids, dims) = unzip $ unSegSpace space
      dims' = map toInt64Exp dims

  -- Our group size is the number of groups for the stage 1 kernel.
  let group_size = Count $ unCount num_groups
      group_size' = fmap toInt64Exp group_size

  -- 'groupScan' hands us 32-bit thread-local indices; translate them
  -- to the flat indices each thread works on (same formula as
  -- @flat_idx@ below) before asking the segment predicate.
  let crossesSegment' = do
        f <- crossesSegment
        Just $ \from to ->
          f
            ((sExt64 from + 1) * elems_per_group - 1)
            ((sExt64 to + 1) * elems_per_group - 1)

  sKernelThread "scan_stage2" (segFlat space) (defKernelAttrs 1 group_size') $ do
    constants <- kernelConstants <$> askEnv
    per_scan_local_arrs <- makeLocalArrays group_size (tvSize stage1_num_threads) scans
    let per_scan_rets = map (lambdaReturnType . segBinOpLambda) scans
        per_scan_pes = segBinOpChunks scans all_pes

    flat_idx <-
      dPrimV "flat_idx" $
        (sExt64 (kernelLocalThreadId constants) + 1) * elems_per_group - 1
    -- Construct segment indices.
    zipWithM_ dPrimV_ gtids $ unflattenIndex dims' $ tvExp flat_idx

    forM_ (zip4 scans per_scan_local_arrs per_scan_rets per_scan_pes) $
      \(SegBinOp _ scan_op nes vec_shape, local_arrs, rets, pes) ->
        sLoopNest vec_shape $ \vec_is -> do
          let glob_is = map Imp.le64 gtids ++ vec_is

              -- NOTE(review): 'foldl1' is partial; this assumes the
              -- segment space has at least one dimension.
              in_bounds =
                foldl1 (.&&.) $ zipWith (.<.) (map Imp.le64 gtids) dims'

              -- Local array slot <- pattern element at the global index.
              when_in_bounds = forM_ (zip3 rets local_arrs pes) $ \(t, arr, pe) ->
                copyDWIMFix
                  arr
                  [localArrayIndex constants t]
                  (Var $ patElemName pe)
                  glob_is

              -- Local array slot <- neutral element.
              when_out_of_bounds = forM_ (zip3 rets local_arrs nes) $ \(t, arr, ne) ->
                copyDWIMFix arr [localArrayIndex constants t] ne []
              (_, _, barrier) =
                barrierFor scan_op

          sComment "threads in bound read carries; others get neutral element" $
            sIf in_bounds when_in_bounds when_out_of_bounds

          barrier

          groupScan
            crossesSegment'
            (sExt64 $ tvExp stage1_num_threads)
            (sExt64 $ kernelGroupSize constants)
            scan_op
            local_arrs

          sComment "threads in bounds write scanned carries" $
            sWhen in_bounds $
              forM_ (zip3 rets pes local_arrs) $ \(t, pe, arr) ->
                copyDWIMFix
                  (patElemName pe)
                  glob_is
                  (Var arr)
                  [localArrayIndex constants t]

-- | Third stage of the two-pass scan: launch the @scan_stage3@ kernel,
-- in which each thread handles one element of the iteration space
-- and, where applicable, combines the carry-in from the preceding
-- stage-one group into that element.
--
-- Arguments, in order:
--
-- * the pattern whose elements are both read from and written back to;
--
-- * the number of groups and the group size used for the kernel
--   launch (the group size is also used to decide how many virtual
--   groups are required);
--
-- * the number of elements per group, used to map a flat index back
--   to the group it originally belonged to;
--
-- * an optional predicate saying whether a segment boundary lies
--   between two flat indexes ('Nothing' is treated as "never");
--
-- * the iteration space;
--
-- * the scan operators, applied one after another to their respective
--   chunks of the pattern.
scanStage3 ::
  Pat LetDecMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  Imp.TExp Int64 ->
  CrossesSegment ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  CallKernelGen ()
scanStage3 (Pat all_pes) num_groups group_size elems_per_group crossesSegment space scans = do
  let num_groups' = fmap toInt64Exp num_groups
      group_size' = fmap toInt64Exp group_size
      (gtids, dims) = unzip $ unSegSpace space
      dims' = map toInt64Exp dims
  -- One thread per element, so we need enough (virtual) groups to
  -- cover the whole flattened iteration space.
  required_groups <-
    dPrimVE "required_groups" $
      sExt32 $ product dims' `divUp` sExt64 (unCount group_size')

  sKernelThread "scan_stage3" (segFlat space) (defKernelAttrs num_groups' group_size') $
    virtualiseGroups SegVirt required_groups $ \virt_group_id -> do
      constants <- kernelConstants <$> askEnv

      -- Compute our logical index.
      flat_idx <-
        dPrimVE "flat_idx" $
          sExt64 virt_group_id * sExt64 (unCount group_size')
            + sExt64 (kernelLocalThreadId constants)
      zipWithM_ dPrimV_ gtids $ unflattenIndex dims' flat_idx

      -- Figure out which group this element was originally in.
      orig_group <- dPrimV "orig_group" $ flat_idx `quot` elems_per_group
      -- Then the index of the carry-in of the preceding group.
      carry_in_flat_idx <-
        dPrimV "carry_in_flat_idx" $
          tvExp orig_group * elems_per_group - 1
      -- Figure out the logical index of the carry-in.
      let carry_in_idx = unflattenIndex dims' $ tvExp carry_in_flat_idx

      -- Apply the carry if we are not in the scan results for the first
      -- group, and are not the last element in such a group (because
      -- then the carry was updated in stage 2), and we are not crossing
      -- a segment boundary.
      --
      -- Note that 'foldl1' is partial: 'in_bounds' relies on the space
      -- having at least one dimension.
      let in_bounds =
            foldl1 (.&&.) $ zipWith (.<.) (map Imp.le64 gtids) dims'
          crosses_segment =
            fromMaybe false $
              crossesSegment
                <*> pure (tvExp carry_in_flat_idx)
                <*> pure flat_idx
          is_a_carry = flat_idx .==. (tvExp orig_group + 1) * elems_per_group - 1
          no_carry_in = tvExp orig_group .==. 0 .||. is_a_carry .||. crosses_segment

      let per_scan_pes = segBinOpChunks scans all_pes
      sWhen in_bounds $
        sUnless no_carry_in $
          forM_ (zip per_scan_pes scans) $
            \(pes, SegBinOp _ scan_op nes vec_shape) -> do
              dScope Nothing $ scopeOfLParams $ lambdaParams scan_op
              -- The first @length nes@ parameters are the left-hand
              -- operands, the remainder the right-hand operands.
              let (scan_x_params, scan_y_params) =
                    splitAt (length nes) $ lambdaParams scan_op

              sLoopNest vec_shape $ \vec_is -> do
                -- x parameters: the carry-in element.
                forM_ (zip scan_x_params pes) $ \(p, pe) ->
                  copyDWIMFix
                    (paramName p)
                    []
                    (Var $ patElemName pe)
                    (carry_in_idx ++ vec_is)

                -- y parameters: this thread's own element.
                forM_ (zip scan_y_params pes) $ \(p, pe) ->
                  copyDWIMFix
                    (paramName p)
                    []
                    (Var $ patElemName pe)
                    (map Imp.le64 gtids ++ vec_is)

                -- Run the operator; the x parameters are passed as the
                -- destination and are what gets written back below.
                compileBody' scan_x_params $ lambdaBody scan_op

                -- Store the combined value at this thread's own index.
                forM_ (zip scan_x_params pes) $ \(p, pe) ->
                  copyDWIMFix
                    (patElemName pe)
                    (map Imp.le64 gtids ++ vec_is)
                    (Var $ paramName p)
                    []

-- | Compile 'SegScan' instance to host-level code with calls to
-- various kernels.
--
-- The generated host code runs three stages in sequence: 'scanStage1'
-- (which also yields the stage-one thread count, the number of
-- elements per group and the segment-crossing predicate), then
-- 'scanStage2', then 'scanStage3'.  Stage one is run with a capped
-- number of groups, while stage three uses the group count and group
-- size of the 'SegLevel' unchanged.
compileSegScan ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegScan pat lvl space scans kbody = do
  -- Since stage 2 involves a group size equal to the number of groups
  -- used for stage 1, we have to cap this number to the maximum group
  -- size.
  stage1_max_num_groups <- dPrim "stage1_max_num_groups" int64
  sOp $ Imp.GetSizeMax (tvVar stage1_max_num_groups) SizeGroup

  stage1_num_groups <-
    fmap (Imp.Count . tvSize) $
      dPrimV "stage1_num_groups" $
        sMin64 (tvExp stage1_max_num_groups) $
          toInt64Exp $ Imp.unCount $ segNumGroups lvl

  (stage1_num_threads, elems_per_group, crossesSegment) <-
    scanStage1 pat stage1_num_groups (segGroupSize lvl) space scans kbody

  emit $ Imp.DebugPrint "elems_per_group" $ Just $ untyped elems_per_group

  scanStage2 pat stage1_num_threads elems_per_group stage1_num_groups crossesSegment space scans
  scanStage3 pat (segNumGroups lvl) (segGroupSize lvl) elems_per_group crossesSegment space scans