{-# LANGUAGE TypeFamilies #-}

-- | Code generation for segmented and non-segmented scans.  Uses a
-- fairly inefficient two-pass algorithm, but can handle anything.
module Futhark.CodeGen.ImpGen.GPU.SegScan.TwoPass (compileSegScan) where

import Control.Monad.Except
import Control.Monad.State
import Data.List (delete, find, foldl', zip4)
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.Transform.Rename
import Futhark.Util (takeLast)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)

-- Aggressively try to reuse memory for different SegBinOps, because
-- we will run them sequentially after another.
makeLocalArrays ::
  Count GroupSize SubExp ->
  SubExp ->
  [SegBinOp GPUMem] ->
  InKernelGen [[VName]]
makeLocalArrays (Count group_size) num_threads scans = do
  -- Allocate per-scan arrays while threading through a pool of
  -- (sizes, memory block) pairs that later scans may reuse.
  (arrs, mems_and_sizes) <- runStateT (mapM onScan scans) mempty
  let maxSize sizes = Imp.bytes $ foldl' sMax64 1 $ map Imp.unCount sizes
  -- Each memory block must be big enough for the largest array that
  -- was ever placed in it.
  forM_ mems_and_sizes $ \(sizes, mem) ->
    sAlloc_ mem (maxSize sizes) (Space "local")
  pure arrs
  where
    onScan (SegBinOp _ scan_op nes _) = do
      let (scan_x_params, _scan_y_params) =
            splitAt (length nes) $ lambdaParams scan_op
      (arrs, used_mems) <- fmap unzip $
        forM scan_x_params $ \p ->
          case paramDec p of
            MemArray pt shape _ (ArrayIn mem _) -> do
              -- Array-typed parameter: construct a per-thread view of
              -- the existing memory block; nothing new to allocate, so
              -- no memory is added to the reuse pool.
              let shape' = Shape [num_threads] <> shape
              arr <-
                lift . sArray "scan_arr" pt shape' mem $
                  IxFun.iota $ map pe64 $ shapeDims shape'
              pure (arr, [])
            _ -> do
              -- Scalar parameter: one element per thread in the group,
              -- backed by a (possibly reused) local memory block.
              let pt = elemType $ paramType p
                  shape = Shape [group_size]
              (sizes, mem') <- getMem pt shape
              arr <- lift $ sArrayInMem "scan_arr" pt shape mem'
              pure (arr, [(sizes, mem')])
      -- Return the blocks we used to the pool for subsequent scans.
      modify (<> concat used_mems)
      pure arrs

    -- Obtain a memory block for an array of the given element type and
    -- shape: prefer a pooled block that already held an array of
    -- exactly this size, otherwise repurpose any pooled block
    -- (recording the new size), otherwise declare a fresh one.
    getMem pt shape = do
      let size = typeSize $ Array pt shape NoUniqueness
      mems <- get
      case (find ((size `elem`) . fst) mems, mems) of
        (Just mem, _) -> do
          modify $ delete mem
          pure mem
        (Nothing, (size', mem) : mems') -> do
          put mems'
          pure (size : size', mem)
        (Nothing, []) -> do
          mem <- lift $ sDeclareMem "scan_arr_mem" $ Space "local"
          pure ([size], mem)

type CrossesSegment = Maybe (Imp.TExp Int64 -> Imp.TExp Int64 -> Imp.TExp Bool)

-- | Index at which a thread accesses its slot of a scan array.
-- Scalar (primitive-typed) scan arrays are group-local with one slot
-- per local thread; other arrays are views sized by the total thread
-- count, so they are indexed by the global thread id.
localArrayIndex :: KernelConstants -> Type -> Imp.TExp Int64
localArrayIndex constants t =
  if primType t
    then sExt64 (kernelLocalThreadId constants)
    else sExt64 (kernelGlobalThreadId constants)

-- | Determine the synchronisation needed around applications of this
-- scan operator: whether it is an \"array scan\" (some result is not a
-- primitive), the fence strength required, and an action emitting the
-- corresponding barrier.  A global fence is used for array scans —
-- presumably because their operands do not live in local memory.
barrierFor :: Lambda GPUMem -> (Bool, Imp.Fence, InKernelGen ())
barrierFor scan_op = (array_scan, fence, sOp $ Imp.Barrier fence)
  where
    array_scan = not $ all primType $ lambdaReturnType scan_op
    fence
      | array_scan = Imp.FenceGlobal
      | otherwise = Imp.FenceLocal

-- | The parameters of a 'SegBinOp' lambda split into the accumulator
-- (\"x\") half and the element (\"y\") half; each half has as many
-- parameters as there are neutral elements.
xParams, yParams :: SegBinOp GPUMem -> [LParam GPUMem]
xParams scan =
  take (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
yParams scan =
  drop (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))

-- | Make the to-scan values of one 'SegBinOp' available for scanning.
-- For a vectorised scan (nonzero shape rank) the results are written
-- to the destination arrays in global memory, indexed by the segment
-- indices; otherwise they are copied directly into the operator's
-- \"y\" parameters.
writeToScanValues ::
  [VName] ->
  ([PatElem LetDecMem], SegBinOp GPUMem, [KernelResult]) ->
  InKernelGen ()
writeToScanValues gtids (pes, scan, scan_res)
  | shapeRank (segBinOpShape scan) > 0 =
      forM_ (zip pes scan_res) $ \(pe, res) ->
        copyDWIMFix
          (patElemName pe)
          (map Imp.le64 gtids)
          (kernelResultSubExp res)
          []
  | otherwise =
      forM_ (zip (yParams scan) scan_res) $ \(p, res) ->
        copyDWIMFix (paramName p) [] (kernelResultSubExp res) []

-- | For a vectorised scan, read the to-scan values back from the
-- destination arrays (at indices @is@) into the operator's \"y\"
-- parameters.  For a non-vectorised scan the parameters were already
-- populated by 'writeToScanValues', so nothing needs to happen.
readToScanValues ::
  [Imp.TExp Int64] ->
  [PatElem LetDecMem] ->
  SegBinOp GPUMem ->
  InKernelGen ()
readToScanValues is pes scan
  | shapeRank (segBinOpShape scan) > 0 =
      forM_ (zip (yParams scan) pes) $ \(p, pe) ->
        copyDWIMFix (paramName p) [] (Var (patElemName pe)) is
  | otherwise =
      pure ()

-- | Initialise the carry (\"x\") parameters of a vectorised scan for
-- the current chunk.  Thread 0 of any chunk but the first reloads the
-- carry from the last element written by the previous chunk; all other
-- threads (and the first chunk) start from the neutral elements.  For
-- a non-vectorised scan this is a no-op.
readCarries ::
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  [Imp.TExp Int64] ->
  [Imp.TExp Int64] ->
  [PatElem LetDecMem] ->
  SegBinOp GPUMem ->
  InKernelGen ()
readCarries chunk_id chunk_offset dims' vec_is pes scan
  | shapeRank (segBinOpShape scan) > 0 = do
      ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
      -- We may have to reload the carries from the output of the
      -- previous chunk.
      sIf
        (chunk_id .>. 0 .&&. ltid .==. 0)
        ( do
            -- The carry is the element just before this chunk's offset.
            let is = unflattenIndex dims' $ chunk_offset - 1
            forM_ (zip (xParams scan) pes) $ \(p, pe) ->
              copyDWIMFix (paramName p) [] (Var (patElemName pe)) (is ++ vec_is)
        )
        ( forM_ (zip (xParams scan) (segBinOpNeutral scan)) $ \(p, ne) ->
            copyDWIMFix (paramName p) [] ne []
        )
  | otherwise =
      pure ()

-- | Produce partially scanned intervals; one per workgroup.
scanStage1 ::
  Pat LetDecMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen (TV Int32, Imp.TExp Int64, CrossesSegment)
scanStage1 :: Pat LParamMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp GPUMem]
-> KernelBody GPUMem
-> CallKernelGen (TV Int32, TPrimExp Int64 VName, CrossesSegment)
scanStage1 (Pat [PatElem LParamMem]
all_pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegBinOp GPUMem]
scans KernelBody GPUMem
kbody = do
  let num_groups' :: Count NumGroups (TPrimExp Int64 VName)
num_groups' = forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64 Count NumGroups SubExp
num_groups
      group_size' :: Count GroupSize (TPrimExp Int64 VName)
group_size' = forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64 Count GroupSize SubExp
group_size
  TV Int32
num_threads <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"num_threads" forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$ forall {k} (u :: k) e. Count u e -> e
unCount Count NumGroups (TPrimExp Int64 VName)
num_groups' forall a. Num a => a -> a -> a
* forall {k} (u :: k) e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 VName)
group_size'

  let ([VName]
gtids, [SubExp]
dims) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      dims' :: [TPrimExp Int64 VName]
dims' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
  let num_elements :: TPrimExp Int64 VName
num_elements = forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims'
      elems_per_thread :: TPrimExp Int64 VName
elems_per_thread = TPrimExp Int64 VName
num_elements forall e. IntegralExp e => e -> e -> e
`divUp` forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
num_threads)
      elems_per_group :: TPrimExp Int64 VName
elems_per_group = forall {k} (u :: k) e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 VName)
group_size' forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
elems_per_thread

  let crossesSegment :: CrossesSegment
crossesSegment =
        case forall a. [a] -> [a]
reverse [TPrimExp Int64 VName]
dims' of
          TPrimExp Int64 VName
segment_size : TPrimExp Int64 VName
_ : [TPrimExp Int64 VName]
_ -> forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
from TPrimExp Int64 VName
to ->
            (TPrimExp Int64 VName
to forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
from) forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TPrimExp Int64 VName
to forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName
segment_size)
          [TPrimExp Int64 VName]
_ -> forall a. Maybe a
Nothing

  SpaceId
-> VName
-> KernelAttrs
-> ImpM GPUMem KernelEnv KernelOp ()
-> CallKernelGen ()
sKernelThread SpaceId
"scan_stage1" (SegSpace -> VName
segFlat SegSpace
space) (Count NumGroups SubExp -> Count GroupSize SubExp -> KernelAttrs
defKernelAttrs Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size) forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. ImpM rep r op r
askEnv
    [[VName]]
all_local_arrs <- Count GroupSize SubExp
-> SubExp -> [SegBinOp GPUMem] -> InKernelGen [[VName]]
makeLocalArrays Count GroupSize SubExp
group_size (forall {k} (t :: k). TV t -> SubExp
tvSize TV Int32
num_threads) [SegBinOp GPUMem]
scans

    -- The variables from scan_op will be used for the carry and such
    -- in the big chunking loop.
    forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [SegBinOp GPUMem]
scans forall a b. (a -> b) -> a -> b
$ \SegBinOp GPUMem
scan -> do
      forall {k} (rep :: k) inner r op.
Mem rep inner =>
Maybe (Exp rep) -> Scope rep -> ImpM rep r op ()
dScope forall a. Maybe a
Nothing forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k) dec.
(LParamInfo rep ~ dec) =>
[Param dec] -> Scope rep
scopeOfLParams forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan
      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scan) (forall {k} (rep :: k). SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan)) forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
ne) ->
        forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
ne []

    forall {k1} {k2} (t :: k1) (rep :: k2) r op.
SpaceId
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor SpaceId
"j" TPrimExp Int64 VName
elems_per_thread forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
j -> do
      TV Int64
chunk_offset <-
        forall {k1} {k2} (t :: k1) (rep :: k2) r op.
SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"chunk_offset" forall a b. (a -> b) -> a -> b
$
          forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants) forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
j
            forall a. Num a => a -> a -> a
+ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 VName
kernelGroupId KernelConstants
constants) forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
elems_per_group
      TV Int64
flat_idx <-
        forall {k1} {k2} (t :: k1) (rep :: k2) r op.
SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"flat_idx" forall a b. (a -> b) -> a -> b
$
          forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_offset forall a. Num a => a -> a -> a
+ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 VName
kernelLocalThreadId KernelConstants
constants)
      -- Construct segment indices.
      forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
gtids forall a b. (a -> b) -> a -> b
$ forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TPrimExp Int64 VName]
dims' forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
flat_idx

      let per_scan_pes :: [[PatElem LParamMem]]
per_scan_pes = forall {k} (rep :: k) a. [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks [SegBinOp GPUMem]
scans [PatElem LParamMem]
all_pes

          in_bounds :: TExp Bool
in_bounds =
            forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldl1 forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
(.&&.) forall a b. (a -> b) -> a -> b
$ forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
(.<.) (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids) [TPrimExp Int64 VName]
dims'

          when_in_bounds :: ImpM GPUMem KernelEnv KernelOp ()
when_in_bounds = forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) forall a b. (a -> b) -> a -> b
$ do
            let ([KernelResult]
all_scan_res, [KernelResult]
map_res) =
                  forall a. Int -> [a] -> ([a], [a])
splitAt (forall {k} (rep :: k). [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem]
scans) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody
                per_scan_res :: [[KernelResult]]
per_scan_res =
                  forall {k} (rep :: k) a. [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks [SegBinOp GPUMem]
scans [KernelResult]
all_scan_res

            forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"write to-scan values to parameters" forall a b. (a -> b) -> a -> b
$
              forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ([VName]
-> ([PatElem LParamMem], SegBinOp GPUMem, [KernelResult])
-> ImpM GPUMem KernelEnv KernelOp ()
writeToScanValues [VName]
gtids) forall a b. (a -> b) -> a -> b
$
                forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElem LParamMem]]
per_scan_pes [SegBinOp GPUMem]
scans [[KernelResult]]
per_scan_res

            forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"write mapped values results to global memory" forall a b. (a -> b) -> a -> b
$
              forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip (forall a. Int -> [a] -> [a]
takeLast (forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
map_res) [PatElem LParamMem]
all_pes) [KernelResult]
map_res) forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, KernelResult
se) ->
                forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
                  (forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
                  (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids)
                  (KernelResult -> SubExp
kernelResultSubExp KernelResult
se)
                  []

      forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"threads in bounds read input" forall a b. (a -> b) -> a -> b
$
        forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
in_bounds ImpM GPUMem KernelEnv KernelOp ()
when_in_bounds

      forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (forall (t :: * -> *) a. Foldable t => t a -> Bool
null forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). SegBinOp rep -> Shape
segBinOpShape) [SegBinOp GPUMem]
scans) forall a b. (a -> b) -> a -> b
$
        forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$
          Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceGlobal

      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElem LParamMem]]
per_scan_pes [SegBinOp GPUMem]
scans [[VName]]
all_local_arrs) forall a b. (a -> b) -> a -> b
$
        \([PatElem LParamMem]
pes, scan :: SegBinOp GPUMem
scan@(SegBinOp Commutativity
_ Lambda GPUMem
scan_op [SubExp]
nes Shape
vec_shape), [VName]
local_arrs) ->
          forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"do one intra-group scan operation" forall a b. (a -> b) -> a -> b
$ do
            let rets :: [Type]
rets = forall {k} (rep :: k). Lambda rep -> [Type]
lambdaReturnType Lambda GPUMem
scan_op
                scan_x_params :: [LParam GPUMem]
scan_x_params = SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scan
                (Bool
array_scan, Fence
fence, ImpM GPUMem KernelEnv KernelOp ()
barrier) = Lambda GPUMem -> (Bool, Fence, ImpM GPUMem KernelEnv KernelOp ())
barrierFor Lambda GPUMem
scan_op

            forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan ImpM GPUMem KernelEnv KernelOp ()
barrier

            forall {k} (rep :: k) r op.
Shape
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
vec_shape forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
vec_is -> do
              forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"maybe restore some to-scan values to parameters, or read neutral" forall a b. (a -> b) -> a -> b
$
                forall {k} (rep :: k) r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
                  TExp Bool
in_bounds
                  ( do
                      [TPrimExp Int64 VName]
-> [PatElem LParamMem]
-> SegBinOp GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
readToScanValues (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
vec_is) [PatElem LParamMem]
pes SegBinOp GPUMem
scan
                      TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName]
-> [PatElem LParamMem]
-> SegBinOp GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
readCarries TPrimExp Int64 VName
j (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_offset) [TPrimExp Int64 VName]
dims' [TPrimExp Int64 VName]
vec_is [PatElem LParamMem]
pes SegBinOp GPUMem
scan
                  )
                  ( forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp GPUMem -> [LParam GPUMem]
yParams SegBinOp GPUMem
scan) (forall {k} (rep :: k). SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan)) forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
ne) ->
                      forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
ne []
                  )

              forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"combine with carry and write to local memory" forall a b. (a -> b) -> a -> b
$
                forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op) forall a b. (a -> b) -> a -> b
$
                  forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Type]
rets [VName]
local_arrs forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op) forall a b. (a -> b) -> a -> b
$
                    \(Type
t, VName
arr, SubExp
se) ->
                      forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
arr [KernelConstants -> Type -> TPrimExp Int64 VName
localArrayIndex KernelConstants
constants Type
t] SubExp
se []

              let crossesSegment' :: Maybe (TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool)
crossesSegment' = do
                    TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
f <- CrossesSegment
crossesSegment
                    forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ \TPrimExp Int32 VName
from TPrimExp Int32 VName
to ->
                      let from' :: TPrimExp Int64 VName
from' = forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
from forall a. Num a => a -> a -> a
+ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_offset
                          to' :: TPrimExp Int64 VName
to' = forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
to forall a. Num a => a -> a -> a
+ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_offset
                       in TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
f TPrimExp Int64 VName
from' TPrimExp Int64 VName
to'

              forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
fence

              -- We need to avoid parameter name clashes.
              Lambda GPUMem
scan_op_renamed <- forall {k} (rep :: k) (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scan_op
              Maybe (TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool)
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> Lambda GPUMem
-> [VName]
-> ImpM GPUMem KernelEnv KernelOp ()
groupScan
                Maybe (TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool)
crossesSegment'
                (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
num_threads)
                (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants)
                Lambda GPUMem
scan_op_renamed
                [VName]
local_arrs

              forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"threads in bounds write partial scan result" forall a b. (a -> b) -> a -> b
$
                forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
in_bounds forall a b. (a -> b) -> a -> b
$
                  forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Type]
rets [PatElem LParamMem]
pes [VName]
local_arrs) forall a b. (a -> b) -> a -> b
$ \(Type
t, PatElem LParamMem
pe, VName
arr) ->
                    forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
                      (forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
                      (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
vec_is)
                      (VName -> SubExp
Var VName
arr)
                      [KernelConstants -> Type -> TPrimExp Int64 VName
localArrayIndex KernelConstants
constants Type
t]

              ImpM GPUMem KernelEnv KernelOp ()
barrier

              let load_carry :: ImpM GPUMem KernelEnv KernelOp ()
load_carry =
                    forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
local_arrs [LParam GPUMem]
scan_x_params) forall a b. (a -> b) -> a -> b
$ \(VName
arr, Param LParamMem
p) ->
                      forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
                        (forall dec. Param dec -> VName
paramName Param LParamMem
p)
                        []
                        (VName -> SubExp
Var VName
arr)
                        [ if forall shape u. TypeBase shape u -> Bool
primType forall a b. (a -> b) -> a -> b
$ forall dec. Typed dec => Param dec -> Type
paramType Param LParamMem
p
                            then forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants) forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1
                            else
                              (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 VName
kernelGroupId KernelConstants
constants) forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1)
                                forall a. Num a => a -> a -> a
* forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants)
                                forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1
                        ]
                  load_neutral :: ImpM GPUMem KernelEnv KernelOp ()
load_neutral =
                    forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
nes [LParam GPUMem]
scan_x_params) forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, Param LParamMem
p) ->
                      forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
ne []

              forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"first thread reads last element as carry-in for next iteration" forall a b. (a -> b) -> a -> b
$ do
                TExp Bool
crosses_segment <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
SpaceId -> TExp t -> ImpM rep r op (TExp t)
dPrimVE SpaceId
"crosses_segment" forall a b. (a -> b) -> a -> b
$
                  case CrossesSegment
crossesSegment of
                    CrossesSegment
Nothing -> forall v. TPrimExp Bool v
false
                    Just TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
f ->
                      TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
f
                        ( forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_offset
                            forall a. Num a => a -> a -> a
+ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants)
                            forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1
                        )
                        ( forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_offset
                            forall a. Num a => a -> a -> a
+ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants)
                        )
                TExp Bool
should_load_carry <-
                  forall {k1} {k2} (t :: k1) (rep :: k2) r op.
SpaceId -> TExp t -> ImpM rep r op (TExp t)
dPrimVE SpaceId
"should_load_carry" forall a b. (a -> b) -> a -> b
$
                    KernelConstants -> TPrimExp Int32 VName
kernelLocalThreadId KernelConstants
constants forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 VName
0 forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TExp Bool
crosses_segment
                forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
should_load_carry ImpM GPUMem KernelEnv KernelOp ()
load_carry
                forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan ImpM GPUMem KernelEnv KernelOp ()
barrier
                forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
should_load_carry ImpM GPUMem KernelEnv KernelOp ()
load_neutral

              ImpM GPUMem KernelEnv KernelOp ()
barrier

  forall (f :: * -> *) a. Applicative f => a -> f a
pure (TV Int32
num_threads, TPrimExp Int64 VName
elems_per_group, CrossesSegment
crossesSegment)

scanStage2 ::
  Pat LetDecMem ->
  TV Int32 ->
  Imp.TExp Int64 ->
  Count NumGroups SubExp ->
  CrossesSegment ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  CallKernelGen ()
scanStage2 :: Pat LParamMem
-> TV Int32
-> TPrimExp Int64 VName
-> Count NumGroups SubExp
-> CrossesSegment
-> SegSpace
-> [SegBinOp GPUMem]
-> CallKernelGen ()
scanStage2 (Pat [PatElem LParamMem]
all_pes) TV Int32
stage1_num_threads TPrimExp Int64 VName
elems_per_group Count NumGroups SubExp
num_groups CrossesSegment
crossesSegment SegSpace
space [SegBinOp GPUMem]
scans = do
  let ([VName]
gtids, [SubExp]
dims) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      dims' :: [TPrimExp Int64 VName]
dims' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims

  -- Our group size is the number of groups for the stage 1 kernel.
  let group_size :: Count GroupSize SubExp
group_size = forall {k} (u :: k) e. e -> Count u e
Count forall a b. (a -> b) -> a -> b
$ forall {k} (u :: k) e. Count u e -> e
unCount Count NumGroups SubExp
num_groups

  let crossesSegment' :: Maybe (TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool)
crossesSegment' = do
        TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
f <- CrossesSegment
crossesSegment
        forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ \TPrimExp Int32 VName
from TPrimExp Int32 VName
to ->
          TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
f
            ((forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
from forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1) forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
elems_per_group forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
            ((forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
to forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1) forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
elems_per_group forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)

  SpaceId
-> VName
-> KernelAttrs
-> ImpM GPUMem KernelEnv KernelOp ()
-> CallKernelGen ()
sKernelThread SpaceId
"scan_stage2" (SegSpace -> VName
segFlat SegSpace
space) (Count NumGroups SubExp -> Count GroupSize SubExp -> KernelAttrs
defKernelAttrs (forall {k} (u :: k) e. e -> Count u e
Count (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1)) Count GroupSize SubExp
group_size) forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. ImpM rep r op r
askEnv
    [[VName]]
per_scan_local_arrs <- Count GroupSize SubExp
-> SubExp -> [SegBinOp GPUMem] -> InKernelGen [[VName]]
makeLocalArrays Count GroupSize SubExp
group_size (forall {k} (t :: k). TV t -> SubExp
tvSize TV Int32
stage1_num_threads) [SegBinOp GPUMem]
scans
    let per_scan_rets :: [[Type]]
per_scan_rets = forall a b. (a -> b) -> [a] -> [b]
map (forall {k} (rep :: k). Lambda rep -> [Type]
lambdaReturnType forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda) [SegBinOp GPUMem]
scans
        per_scan_pes :: [[PatElem LParamMem]]
per_scan_pes = forall {k} (rep :: k) a. [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks [SegBinOp GPUMem]
scans [PatElem LParamMem]
all_pes

    TV Int64
flat_idx <-
      forall {k1} {k2} (t :: k1) (rep :: k2) r op.
SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"flat_idx" forall a b. (a -> b) -> a -> b
$
        (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 VName
kernelLocalThreadId KernelConstants
constants) forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1) forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
elems_per_group forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1
    -- Construct segment indices.
    forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
gtids forall a b. (a -> b) -> a -> b
$ forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TPrimExp Int64 VName]
dims' forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
flat_idx

    forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 [SegBinOp GPUMem]
scans [[VName]]
per_scan_local_arrs [[Type]]
per_scan_rets [[PatElem LParamMem]]
per_scan_pes) forall a b. (a -> b) -> a -> b
$
      \(SegBinOp Commutativity
_ Lambda GPUMem
scan_op [SubExp]
nes Shape
vec_shape, [VName]
local_arrs, [Type]
rets, [PatElem LParamMem]
pes) ->
        forall {k} (rep :: k) r op.
Shape
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
vec_shape forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
vec_is -> do
          let glob_is :: [TPrimExp Int64 VName]
glob_is = forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
vec_is

              in_bounds :: TExp Bool
in_bounds =
                forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldl1 forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
(.&&.) forall a b. (a -> b) -> a -> b
$ forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
(.<.) (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids) [TPrimExp Int64 VName]
dims'

              when_in_bounds :: ImpM GPUMem KernelEnv KernelOp ()
when_in_bounds = forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Type]
rets [VName]
local_arrs [PatElem LParamMem]
pes) forall a b. (a -> b) -> a -> b
$ \(Type
t, VName
arr, PatElem LParamMem
pe) ->
                forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
                  VName
arr
                  [KernelConstants -> Type -> TPrimExp Int64 VName
localArrayIndex KernelConstants
constants Type
t]
                  (VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
                  [TPrimExp Int64 VName]
glob_is

              when_out_of_bounds :: ImpM GPUMem KernelEnv KernelOp ()
when_out_of_bounds = forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Type]
rets [VName]
local_arrs [SubExp]
nes) forall a b. (a -> b) -> a -> b
$ \(Type
t, VName
arr, SubExp
ne) ->
                forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
arr [KernelConstants -> Type -> TPrimExp Int64 VName
localArrayIndex KernelConstants
constants Type
t] SubExp
ne []
              (Bool
_, Fence
_, ImpM GPUMem KernelEnv KernelOp ()
barrier) =
                Lambda GPUMem -> (Bool, Fence, ImpM GPUMem KernelEnv KernelOp ())
barrierFor Lambda GPUMem
scan_op

          forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"threads in bound read carries; others get neutral element" forall a b. (a -> b) -> a -> b
$
            forall {k} (rep :: k) r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf TExp Bool
in_bounds ImpM GPUMem KernelEnv KernelOp ()
when_in_bounds ImpM GPUMem KernelEnv KernelOp ()
when_out_of_bounds

          ImpM GPUMem KernelEnv KernelOp ()
barrier

          Maybe (TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool)
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> Lambda GPUMem
-> [VName]
-> ImpM GPUMem KernelEnv KernelOp ()
groupScan
            Maybe (TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool)
crossesSegment'
            (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
stage1_num_threads)
            (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 forall a b. (a -> b) -> a -> b
$ KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants)
            Lambda GPUMem
scan_op
            [VName]
local_arrs

          forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"threads in bounds write scanned carries" forall a b. (a -> b) -> a -> b
$
            forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
in_bounds forall a b. (a -> b) -> a -> b
$
              forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Type]
rets [PatElem LParamMem]
pes [VName]
local_arrs) forall a b. (a -> b) -> a -> b
$ \(Type
t, PatElem LParamMem
pe, VName
arr) ->
                forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
                  (forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
                  [TPrimExp Int64 VName]
glob_is
                  (VName -> SubExp
Var VName
arr)
                  [KernelConstants -> Type -> TPrimExp Int64 VName
localArrayIndex KernelConstants
constants Type
t]

scanStage3 ::
  Pat LetDecMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  Imp.TExp Int64 ->
  CrossesSegment ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  CallKernelGen ()
scanStage3 :: Pat LParamMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> TPrimExp Int64 VName
-> CrossesSegment
-> SegSpace
-> [SegBinOp GPUMem]
-> CallKernelGen ()
scanStage3 (Pat [PatElem LParamMem]
all_pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size TPrimExp Int64 VName
elems_per_group CrossesSegment
crossesSegment SegSpace
space [SegBinOp GPUMem]
scans = do
  let group_size' :: Count GroupSize (TPrimExp Int64 VName)
group_size' = forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64 Count GroupSize SubExp
group_size
      ([VName]
gtids, [SubExp]
dims) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      dims' :: [TPrimExp Int64 VName]
dims' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
  TPrimExp Int32 VName
required_groups <-
    forall {k1} {k2} (t :: k1) (rep :: k2) r op.
SpaceId -> TExp t -> ImpM rep r op (TExp t)
dPrimVE SpaceId
"required_groups" forall a b. (a -> b) -> a -> b
$
      forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$
        forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims' forall e. IntegralExp e => e -> e -> e
`divUp` forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (forall {k} (u :: k) e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 VName)
group_size')

  SpaceId
-> VName
-> KernelAttrs
-> ImpM GPUMem KernelEnv KernelOp ()
-> CallKernelGen ()
sKernelThread SpaceId
"scan_stage3" (SegSpace -> VName
segFlat SegSpace
space) (Count NumGroups SubExp -> Count GroupSize SubExp -> KernelAttrs
defKernelAttrs Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size) forall a b. (a -> b) -> a -> b
$
    SegVirt
-> TPrimExp Int32 VName
-> (TPrimExp Int32 VName -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
virtualiseGroups SegVirt
SegVirt TPrimExp Int32 VName
required_groups forall a b. (a -> b) -> a -> b
$ \TPrimExp Int32 VName
virt_group_id -> do
      KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. ImpM rep r op r
askEnv

      -- Compute our logical index.
      TPrimExp Int64 VName
flat_idx <-
        forall {k1} {k2} (t :: k1) (rep :: k2) r op.
SpaceId -> TExp t -> ImpM rep r op (TExp t)
dPrimVE SpaceId
"flat_idx" forall a b. (a -> b) -> a -> b
$
          forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
virt_group_id forall a. Num a => a -> a -> a
* forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (forall {k} (u :: k) e. Count u e -> e
unCount Count GroupSize (TPrimExp Int64 VName)
group_size')
            forall a. Num a => a -> a -> a
+ forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 VName
kernelLocalThreadId KernelConstants
constants)
      forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
gtids forall a b. (a -> b) -> a -> b
$ forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TPrimExp Int64 VName]
dims' TPrimExp Int64 VName
flat_idx

      -- Figure out which group this element was originally in.
      TV Int64
orig_group <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"orig_group" forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
flat_idx forall e. IntegralExp e => e -> e -> e
`quot` TPrimExp Int64 VName
elems_per_group
      -- Then the index of the carry-in of the preceding group.
      TV Int64
carry_in_flat_idx <-
        forall {k1} {k2} (t :: k1) (rep :: k2) r op.
SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"carry_in_flat_idx" forall a b. (a -> b) -> a -> b
$
          forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
orig_group forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
elems_per_group forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1
      -- Figure out the logical index of the carry-in.
      let carry_in_idx :: [TPrimExp Int64 VName]
carry_in_idx = forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TPrimExp Int64 VName]
dims' forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
carry_in_flat_idx

      -- Apply the carry if we are not in the scan results for the first
      -- group, and are not the last element in such a group (because
      -- then the carry was updated in stage 2), and we are not crossing
      -- a segment boundary.
      let in_bounds :: TExp Bool
in_bounds =
            forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldl1 forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
(.&&.) forall a b. (a -> b) -> a -> b
$ forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
(.<.) (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids) [TPrimExp Int64 VName]
dims'
          crosses_segment :: TExp Bool
crosses_segment =
            forall a. a -> Maybe a -> a
fromMaybe forall v. TPrimExp Bool v
false forall a b. (a -> b) -> a -> b
$
              CrossesSegment
crossesSegment
                forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
carry_in_flat_idx)
                forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure TPrimExp Int64 VName
flat_idx
          is_a_carry :: TExp Bool
is_a_carry = TPrimExp Int64 VName
flat_idx forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
orig_group forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1) forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
elems_per_group forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1
          no_carry_in :: TExp Bool
no_carry_in = forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
orig_group forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0 forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TExp Bool
is_a_carry forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TExp Bool
crosses_segment

      let per_scan_pes :: [[PatElem LParamMem]]
per_scan_pes = forall {k} (rep :: k) a. [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks [SegBinOp GPUMem]
scans [PatElem LParamMem]
all_pes
      forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
in_bounds forall a b. (a -> b) -> a -> b
$
        forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
no_carry_in forall a b. (a -> b) -> a -> b
$
          forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [[PatElem LParamMem]]
per_scan_pes [SegBinOp GPUMem]
scans) forall a b. (a -> b) -> a -> b
$
            \([PatElem LParamMem]
pes, SegBinOp Commutativity
_ Lambda GPUMem
scan_op [SubExp]
nes Shape
vec_shape) -> do
              forall {k} (rep :: k) inner r op.
Mem rep inner =>
Maybe (Exp rep) -> Scope rep -> ImpM rep r op ()
dScope forall a. Maybe a
Nothing forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k) dec.
(LParamInfo rep ~ dec) =>
[Param dec] -> Scope rep
scopeOfLParams forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scan_op
              let ([Param LParamMem]
scan_x_params, [Param LParamMem]
scan_y_params) =
                    forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scan_op

              forall {k} (rep :: k) r op.
Shape
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
vec_shape forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
vec_is -> do
                forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
scan_x_params [PatElem LParamMem]
pes) forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, PatElem LParamMem
pe) ->
                  forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
                    (forall dec. Param dec -> VName
paramName Param LParamMem
p)
                    []
                    (VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
                    ([TPrimExp Int64 VName]
carry_in_idx forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
vec_is)

                forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
scan_y_params [PatElem LParamMem]
pes) forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, PatElem LParamMem
pe) ->
                  forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
                    (forall dec. Param dec -> VName
paramName Param LParamMem
p)
                    []
                    (VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
                    (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
vec_is)

                forall {k} dec (rep :: k) r op.
[Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [Param LParamMem]
scan_x_params forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op

                forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
scan_x_params [PatElem LParamMem]
pes) forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, PatElem LParamMem
pe) ->
                  forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
                    (forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
                    (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
vec_is)
                    (VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param LParamMem
p)
                    []

-- | Compile 'SegScan' instance to host-level code with calls to
-- various kernels.
compileSegScan ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegScan :: Pat LParamMem
-> SegLevel
-> SegSpace
-> [SegBinOp GPUMem]
-> KernelBody GPUMem
-> CallKernelGen ()
compileSegScan Pat LParamMem
pat SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
scans KernelBody GPUMem
kbody = do
  KernelAttrs
attrs <- SegLevel -> CallKernelGen KernelAttrs
lvlKernelAttrs SegLevel
lvl

  -- Since stage 2 involves a group size equal to the number of groups
  -- used for stage 1, we have to cap this number to the maximum group
  -- size.
  TV Int64
stage1_max_num_groups <- forall {k1} {k2} (rep :: k1) r op (t :: k2).
SpaceId -> PrimType -> ImpM rep r op (TV t)
dPrim SpaceId
"stage1_max_num_groups" PrimType
int64
  forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax (forall {k} (t :: k). TV t -> VName
tvVar TV Int64
stage1_max_num_groups) SizeClass
SizeGroup

  Count NumGroups SubExp
stage1_num_groups <-
    forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall {k} (u :: k) e. e -> Count u e
Imp.Count forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (t :: k). TV t -> SubExp
tvSize) forall a b. (a -> b) -> a -> b
$
      forall {k1} {k2} (t :: k1) (rep :: k2) r op.
SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"stage1_num_groups" forall a b. (a -> b) -> a -> b
$
        forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
stage1_max_num_groups) forall a b. (a -> b) -> a -> b
$
          SubExp -> TPrimExp Int64 VName
pe64 forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (u :: k) e. Count u e -> e
Imp.unCount forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelAttrs -> Count NumGroups SubExp
kAttrNumGroups forall a b. (a -> b) -> a -> b
$
            KernelAttrs
attrs

  (TV Int32
stage1_num_threads, TPrimExp Int64 VName
elems_per_group, CrossesSegment
crossesSegment) <-
    Pat LParamMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp GPUMem]
-> KernelBody GPUMem
-> CallKernelGen (TV Int32, TPrimExp Int64 VName, CrossesSegment)
scanStage1 Pat LParamMem
pat Count NumGroups SubExp
stage1_num_groups (KernelAttrs -> Count GroupSize SubExp
kAttrGroupSize KernelAttrs
attrs) SegSpace
space [SegBinOp GPUMem]
scans KernelBody GPUMem
kbody

  forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. SpaceId -> Maybe Exp -> Code a
Imp.DebugPrint SpaceId
"elems_per_group" forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
elems_per_group

  Pat LParamMem
-> TV Int32
-> TPrimExp Int64 VName
-> Count NumGroups SubExp
-> CrossesSegment
-> SegSpace
-> [SegBinOp GPUMem]
-> CallKernelGen ()
scanStage2 Pat LParamMem
pat TV Int32
stage1_num_threads TPrimExp Int64 VName
elems_per_group Count NumGroups SubExp
stage1_num_groups CrossesSegment
crossesSegment SegSpace
space [SegBinOp GPUMem]
scans
  Pat LParamMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> TPrimExp Int64 VName
-> CrossesSegment
-> SegSpace
-> [SegBinOp GPUMem]
-> CallKernelGen ()
scanStage3 Pat LParamMem
pat (KernelAttrs -> Count NumGroups SubExp
kAttrNumGroups KernelAttrs
attrs) (KernelAttrs -> Count GroupSize SubExp
kAttrGroupSize KernelAttrs
attrs) TPrimExp Int64 VName
elems_per_group CrossesSegment
crossesSegment SegSpace
space [SegBinOp GPUMem]
scans