{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.GPU.SegScan.TwoPass (compileSegScan) where
import Control.Monad.Except
import Control.Monad.State
import Data.List (delete, find, foldl', zip4)
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.GPU as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Transform.Rename
import Futhark.Util (takeLast)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)
-- | Create the per-scan intermediate arrays, one list of arrays per
-- 'SegBinOp' (one array for each of the operator's x-parameters).
--
-- Memory blocks for scalar-typed parameters are pooled and reused
-- across the different scan operators: the state threaded through
-- 'onScan' is a list of @(sizes, mem)@ pairs, where @sizes@ records
-- every size that has been requested of block @mem@.  Once all
-- operators have been visited, each block is allocated in the
-- @"local"@ space with the maximum of its recorded sizes (and at
-- least 1 byte, since the fold starts from 1).
makeLocalArrays ::
  Count GroupSize SubExp ->
  SubExp ->
  [SegBinOp GPUMem] ->
  InKernelGen [[VName]]
makeLocalArrays (Count group_size) num_threads scans = do
  (arrs, mems_and_sizes) <- runStateT (mapM onScan scans) mempty
  let maxSize sizes = Imp.bytes $ foldl' sMax64 1 $ map Imp.unCount sizes
  forM_ mems_and_sizes $ \(sizes, mem) ->
    sAlloc_ mem (maxSize sizes) (Space "local")
  return arrs
  where
    -- Produce one array per x-parameter of the scan operator (the
    -- first @length nes@ lambda parameters).
    onScan (SegBinOp _ scan_op nes _) = do
      let (scan_x_params, _scan_y_params) =
            splitAt (length nes) $ lambdaParams scan_op
      (arrs, used_mems) <- fmap unzip $
        forM scan_x_params $ \p ->
          case paramDec p of
            -- Array-typed parameter: place the array in the memory
            -- block the parameter already names, with an extra outer
            -- dimension of size @num_threads@.  Nothing is taken from
            -- the pool.
            MemArray pt shape _ (ArrayIn mem _) -> do
              let shape' = Shape [num_threads] <> shape
              arr <-
                lift $
                  sArray "scan_arr" pt shape' $
                    ArrayIn mem $
                      IxFun.iota $
                        map pe64 $
                          shapeDims shape'
              return (arr, [])
            -- Any other parameter: a one-dimensional array of
            -- @group_size@ elements in a pooled memory block.
            _ -> do
              let pt = elemType $ paramType p
                  shape = Shape [group_size]
              (sizes, mem') <- getMem pt shape
              arr <- lift $ sArrayInMem "scan_arr" pt shape mem'
              return (arr, [(sizes, mem')])
      -- Blocks handed out by 'getMem' were removed from the pool while
      -- this operator's parameters were processed; put them back so
      -- later operators can reuse them.
      modify (<> concat used_mems)
      return arrs

    -- Take a memory block out of the pool for an array of the given
    -- element type and shape, returning it with its list of sizes.
    getMem pt shape = do
      let size = typeSize $ Array pt shape NoUniqueness
      mems <- get
      case (find ((size `elem`) . fst) mems, mems) of
        -- A pooled block has already been requested at this size.
        (Just mem, _) -> do
          modify $ delete mem
          return mem
        -- Otherwise take the first pooled block and record the new
        -- size against it.
        (Nothing, (size', mem) : mems') -> do
          put mems'
          return (size : size', mem)
        -- Empty pool: declare a fresh block.
        (Nothing, []) -> do
          mem <- lift $ sDeclareMem "scan_arr_mem" $ Space "local"
          return ([size], mem)
-- | An optional predicate over two 64-bit index expressions that
-- yields a boolean expression.  'scanStage1' constructs it as 'Just'
-- when the iteration space has at least two dimensions (binding the
-- two arguments as @from@ and @to@), and as 'Nothing' otherwise.
type CrossesSegment = Maybe (Imp.TExp Int64 -> Imp.TExp Int64 -> Imp.TExp Bool)
-- | The thread index to use for the given element type, sign-extended
-- to 64 bits: the local thread id when the type is primitive, and the
-- global thread id otherwise.
localArrayIndex :: KernelConstants -> Type -> Imp.TExp Int64
localArrayIndex constants t
  | primType t = sExt64 (kernelLocalThreadId constants)
  | otherwise = sExt64 (kernelGlobalThreadId constants)
-- | Barrier information for a scan operator.  Returns a triple of:
--
--   * whether the operator is an "array scan", i.e. at least one of
--     its return types is not primitive;
--
--   * the fence to use: 'Imp.FenceGlobal' for an array scan,
--     'Imp.FenceLocal' otherwise;
--
--   * an action emitting an 'Imp.Barrier' with that fence.
barrierFor :: Lambda GPUMem -> (Bool, Imp.Fence, InKernelGen ())
barrierFor scan_op = (array_scan, fence, sOp $ Imp.Barrier fence)
  where
    array_scan = not $ all primType $ lambdaReturnType scan_op
    fence
      | array_scan = Imp.FenceGlobal
      | otherwise = Imp.FenceLocal
-- | Split the lambda parameters of a scan operator at the number of
-- neutral elements: 'xParams' is the leading part (as many parameters
-- as there are neutral elements), 'yParams' is the remainder.
xParams, yParams :: SegBinOp GPUMem -> [LParam GPUMem]
xParams scan =
  take (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
yParams scan =
  drop (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
-- | Copy the kernel results belonging to one scan operator to where
-- the scan will read them.
--
-- If the operator has a non-empty shape ('segBinOpShape' of rank > 0),
-- each result is written to the corresponding pattern element at the
-- index given by @gtids@.  Otherwise each result is written directly
-- to the corresponding y-parameter of the operator.
writeToScanValues ::
  [VName] ->
  ([PatElem GPUMem], SegBinOp GPUMem, [KernelResult]) ->
  InKernelGen ()
writeToScanValues gtids (pes, scan, scan_res)
  | shapeRank (segBinOpShape scan) > 0 =
    forM_ (zip pes scan_res) $ \(pe, res) ->
      copyDWIMFix
        (patElemName pe)
        (map Imp.vi64 gtids)
        (kernelResultSubExp res)
        []
  | otherwise =
    forM_ (zip (yParams scan) scan_res) $ \(p, res) ->
      copyDWIMFix (paramName p) [] (kernelResultSubExp res) []
-- | For a scan operator with a non-empty shape ('segBinOpShape' of
-- rank > 0), load each y-parameter from the corresponding pattern
-- element at index @is@.  For operators with an empty shape this does
-- nothing.
readToScanValues ::
  [Imp.TExp Int64] ->
  [PatElem GPUMem] ->
  SegBinOp GPUMem ->
  InKernelGen ()
readToScanValues is pes scan
  | shapeRank (segBinOpShape scan) > 0 =
    forM_ (zip (yParams scan) pes) $ \(p, pe) ->
      copyDWIMFix (paramName p) [] (Var (patElemName pe)) is
  | otherwise =
    return ()
-- | Initialise the x-parameters (the carries) of a scan operator with
-- a non-empty shape ('segBinOpShape' of rank > 0).
--
-- When @chunk_offset@ is positive and the local thread id is 0, each
-- x-parameter is read from the corresponding pattern element at the
-- index obtained by unflattening @chunk_offset - 1@ over @dims'@,
-- followed by @vec_is@.  In every other case the x-parameters are set
-- to the operator's neutral elements.  For operators with an empty
-- shape this does nothing.
readCarries ::
  Imp.TExp Int64 ->
  [Imp.TExp Int64] ->
  [Imp.TExp Int64] ->
  [PatElem GPUMem] ->
  SegBinOp GPUMem ->
  InKernelGen ()
readCarries chunk_offset dims' vec_is pes scan
  | shapeRank (segBinOpShape scan) > 0 = do
    ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
    sIf
      (chunk_offset .>. 0 .&&. ltid .==. 0)
      ( do
          -- Index of the element just before this chunk.
          let is = unflattenIndex dims' $ chunk_offset - 1
          forM_ (zip (xParams scan) pes) $ \(p, pe) ->
            copyDWIMFix (paramName p) [] (Var (patElemName pe)) (is ++ vec_is)
      )
      ( forM_ (zip (xParams scan) (segBinOpNeutral scan)) $ \(p, ne) ->
          copyDWIMFix (paramName p) [] ne []
      )
  | otherwise =
    return ()
scanStage1 ::
Pattern GPUMem ->
Count NumGroups SubExp ->
Count GroupSize SubExp ->
SegSpace ->
[SegBinOp GPUMem] ->
KernelBody GPUMem ->
CallKernelGen (TV Int32, Imp.TExp Int64, CrossesSegment)
scanStage1 :: Pattern GPUMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp GPUMem]
-> KernelBody GPUMem
-> CallKernelGen (TV Int32, TExp Int64, CrossesSegment)
scanStage1 (Pattern [PatElem GPUMem]
_ [PatElem GPUMem]
all_pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegBinOp GPUMem]
scans KernelBody GPUMem
kbody = do
let num_groups' :: Count NumGroups (TExp Int64)
num_groups' = (SubExp -> TExp Int64)
-> Count NumGroups SubExp -> Count NumGroups (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count NumGroups SubExp
num_groups
group_size' :: Count GroupSize (TExp Int64)
group_size' = (SubExp -> TExp Int64)
-> Count GroupSize SubExp -> Count GroupSize (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count GroupSize SubExp
group_size
TV Int32
num_threads <- SpaceId
-> TPrimExp Int32 ExpLeaf -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall t rep r op. SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"num_threads" (TPrimExp Int32 ExpLeaf -> ImpM GPUMem HostEnv HostOp (TV Int32))
-> TPrimExp Int32 ExpLeaf -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TPrimExp Int32 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TPrimExp Int32 ExpLeaf)
-> TExp Int64 -> TPrimExp Int32 ExpLeaf
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size'
let ([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
dims' :: [TExp Int64]
dims' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
dims
let num_elements :: TExp Int64
num_elements = [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
dims'
elems_per_thread :: TExp Int64
elems_per_thread = TExp Int64
num_elements TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TPrimExp Int32 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int32
num_threads)
elems_per_group :: TExp Int64
elems_per_group = Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
elems_per_thread
let crossesSegment :: CrossesSegment
crossesSegment =
case [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a]
reverse [TExp Int64]
dims' of
TExp Int64
segment_size : TExp Int64
_ : [TExp Int64]
_ -> (TExp Int64 -> TExp Int64 -> TExp Bool) -> CrossesSegment
forall a. a -> Maybe a
Just ((TExp Int64 -> TExp Int64 -> TExp Bool) -> CrossesSegment)
-> (TExp Int64 -> TExp Int64 -> TExp Bool) -> CrossesSegment
forall a b. (a -> b) -> a -> b
$ \TExp Int64
from TExp Int64
to ->
(TExp Int64
to TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
from) TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TExp Int64
to TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int64
segment_size)
[TExp Int64]
_ -> CrossesSegment
forall a. Maybe a
Nothing
SpaceId
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> CallKernelGen ()
sKernelThread SpaceId
"scan_stage1" Count NumGroups (TExp Int64)
num_groups' Count GroupSize (TExp Int64)
group_size' (SegSpace -> VName
segFlat SegSpace
space) (ImpM GPUMem KernelEnv KernelOp () -> CallKernelGen ())
-> ImpM GPUMem KernelEnv KernelOp () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
[[VName]]
all_local_arrs <- Count GroupSize SubExp
-> SubExp -> [SegBinOp GPUMem] -> InKernelGen [[VName]]
makeLocalArrays Count GroupSize SubExp
group_size (TV Int32 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int32
num_threads) [SegBinOp GPUMem]
scans
[SegBinOp GPUMem]
-> (SegBinOp GPUMem -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [SegBinOp GPUMem]
scans ((SegBinOp GPUMem -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (SegBinOp GPUMem -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \SegBinOp GPUMem
scan -> do
Maybe (Exp GPUMem)
-> Scope GPUMem -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Mem rep =>
Maybe (Exp rep) -> Scope rep -> ImpM rep r op ()
dScope Maybe (Exp GPUMem)
forall a. Maybe a
Nothing (Scope GPUMem -> ImpM GPUMem KernelEnv KernelOp ())
-> Scope GPUMem -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ [Param LParamMem] -> Scope GPUMem
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams ([Param LParamMem] -> Scope GPUMem)
-> [Param LParamMem] -> Scope GPUMem
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams (Lambda GPUMem -> [LParam GPUMem])
-> Lambda GPUMem -> [LParam GPUMem]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan
[(Param LParamMem, SubExp)]
-> ((Param LParamMem, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [SubExp] -> [(Param LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scan) (SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan)) (((Param LParamMem, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((Param LParamMem, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
ne) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
ne []
SpaceId
-> TExp Int64
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op.
SpaceId
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor SpaceId
"j" TExp Int64
elems_per_thread ((TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
j -> do
TV Int64
chunk_offset <-
SpaceId -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall t rep r op. SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"chunk_offset" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
j
TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelGroupId KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
elems_per_group
TV Int64
flat_idx <-
SpaceId -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall t rep r op. SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"flat_idx" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants)
(VName -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> [VName] -> [TExp Int64] -> ImpM GPUMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
gtids ([TExp Int64] -> ImpM GPUMem KernelEnv KernelOp ())
-> [TExp Int64] -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
dims' (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
flat_idx
let per_scan_pes :: [[PatElemT LParamMem]]
per_scan_pes = [SegBinOp GPUMem] -> [PatElemT LParamMem] -> [[PatElemT LParamMem]]
forall rep a. [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks [SegBinOp GPUMem]
scans [PatElem GPUMem]
[PatElemT LParamMem]
all_pes
in_bounds :: TExp Bool
in_bounds =
(TExp Bool -> TExp Bool -> TExp Bool) -> [TExp Bool] -> TExp Bool
forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldl1 TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
(.&&.) ([TExp Bool] -> TExp Bool) -> [TExp Bool] -> TExp Bool
forall a b. (a -> b) -> a -> b
$ (TExp Int64 -> TExp Int64 -> TExp Bool)
-> [TExp Int64] -> [TExp Int64] -> [TExp Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
(.<.) ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
gtids) [TExp Int64]
dims'
when_in_bounds :: ImpM GPUMem KernelEnv KernelOp ()
when_in_bounds = Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
let ([KernelResult]
all_scan_res, [KernelResult]
map_res) =
Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp GPUMem] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem]
scans) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody
per_scan_res :: [[KernelResult]]
per_scan_res =
[SegBinOp GPUMem] -> [KernelResult] -> [[KernelResult]]
forall rep a. [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks [SegBinOp GPUMem]
scans [KernelResult]
all_scan_res
SpaceId
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. SpaceId -> ImpM rep r op () -> ImpM rep r op ()
sComment SpaceId
"write to-scan values to parameters" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
(([PatElemT LParamMem], SegBinOp GPUMem, [KernelResult])
-> ImpM GPUMem KernelEnv KernelOp ())
-> [([PatElemT LParamMem], SegBinOp GPUMem, [KernelResult])]
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ([VName]
-> ([PatElem GPUMem], SegBinOp GPUMem, [KernelResult])
-> ImpM GPUMem KernelEnv KernelOp ()
writeToScanValues [VName]
gtids) ([([PatElemT LParamMem], SegBinOp GPUMem, [KernelResult])]
-> ImpM GPUMem KernelEnv KernelOp ())
-> [([PatElemT LParamMem], SegBinOp GPUMem, [KernelResult])]
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[[PatElemT LParamMem]]
-> [SegBinOp GPUMem]
-> [[KernelResult]]
-> [([PatElemT LParamMem], SegBinOp GPUMem, [KernelResult])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElemT LParamMem]]
per_scan_pes [SegBinOp GPUMem]
scans [[KernelResult]]
per_scan_res
SpaceId
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. SpaceId -> ImpM rep r op () -> ImpM rep r op ()
sComment SpaceId
"write mapped values results to global memory" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(PatElemT LParamMem, KernelResult)]
-> ((PatElemT LParamMem, KernelResult)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LParamMem]
-> [KernelResult] -> [(PatElemT LParamMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int -> [PatElemT LParamMem] -> [PatElemT LParamMem]
forall a. Int -> [a] -> [a]
takeLast ([KernelResult] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
map_res) [PatElem GPUMem]
[PatElemT LParamMem]
all_pes) [KernelResult]
map_res) (((PatElemT LParamMem, KernelResult)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((PatElemT LParamMem, KernelResult)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LParamMem
pe, KernelResult
se) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe)
((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
gtids)
(KernelResult -> SubExp
kernelResultSubExp KernelResult
se)
[]
SpaceId
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. SpaceId -> ImpM rep r op () -> ImpM rep r op ()
sComment SpaceId
"threads in bounds read input" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
TExp Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
in_bounds ImpM GPUMem KernelEnv KernelOp ()
when_in_bounds
[([PatElemT LParamMem], SegBinOp GPUMem, [VName])]
-> (([PatElemT LParamMem], SegBinOp GPUMem, [VName])
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([[PatElemT LParamMem]]
-> [SegBinOp GPUMem]
-> [[VName]]
-> [([PatElemT LParamMem], SegBinOp GPUMem, [VName])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElemT LParamMem]]
per_scan_pes [SegBinOp GPUMem]
scans [[VName]]
all_local_arrs) ((([PatElemT LParamMem], SegBinOp GPUMem, [VName])
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (([PatElemT LParamMem], SegBinOp GPUMem, [VName])
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
\([PatElemT LParamMem]
pes, scan :: SegBinOp GPUMem
scan@(SegBinOp Commutativity
_ Lambda GPUMem
scan_op [SubExp]
nes Shape
vec_shape), [VName]
local_arrs) ->
SpaceId
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. SpaceId -> ImpM rep r op () -> ImpM rep r op ()
sComment SpaceId
"do one intra-group scan operation" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
let rets :: [TypeBase Shape NoUniqueness]
rets = Lambda GPUMem -> [TypeBase Shape NoUniqueness]
forall rep. LambdaT rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda GPUMem
scan_op
scan_x_params :: [LParam GPUMem]
scan_x_params = SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scan
(Bool
array_scan, Fence
fence, ImpM GPUMem KernelEnv KernelOp ()
barrier) = Lambda GPUMem -> (Bool, Fence, ImpM GPUMem KernelEnv KernelOp ())
barrierFor Lambda GPUMem
scan_op
Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan ImpM GPUMem KernelEnv KernelOp ()
barrier
Shape
-> ([TExp Int64] -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
vec_shape (([TExp Int64] -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ([TExp Int64] -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is -> do
SpaceId
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. SpaceId -> ImpM rep r op () -> ImpM rep r op ()
sComment SpaceId
"maybe restore some to-scan values to parameters, or read neutral" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
TExp Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
TExp Bool
in_bounds
( do
[TExp Int64]
-> [PatElem GPUMem]
-> SegBinOp GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
readToScanValues ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
gtids [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is) [PatElem GPUMem]
[PatElemT LParamMem]
pes SegBinOp GPUMem
scan
TExp Int64
-> [TExp Int64]
-> [TExp Int64]
-> [PatElem GPUMem]
-> SegBinOp GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
readCarries (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset) [TExp Int64]
dims' [TExp Int64]
vec_is [PatElem GPUMem]
[PatElemT LParamMem]
pes SegBinOp GPUMem
scan
)
( [(Param LParamMem, SubExp)]
-> ((Param LParamMem, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [SubExp] -> [(Param LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp GPUMem -> [LParam GPUMem]
yParams SegBinOp GPUMem
scan) (SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan)) (((Param LParamMem, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((Param LParamMem, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
ne) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
ne []
)
SpaceId
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. SpaceId -> ImpM rep r op () -> ImpM rep r op ()
sComment SpaceId
"combine with carry and write to local memory" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT GPUMem -> Stms GPUMem
forall rep. BodyT rep -> Stms rep
bodyStms (BodyT GPUMem -> Stms GPUMem) -> BodyT GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda GPUMem
scan_op) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(TypeBase Shape NoUniqueness, VName, SubExp)]
-> ((TypeBase Shape NoUniqueness, VName, SubExp)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TypeBase Shape NoUniqueness]
-> [VName]
-> [SubExp]
-> [(TypeBase Shape NoUniqueness, VName, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TypeBase Shape NoUniqueness]
rets [VName]
local_arrs (BodyT GPUMem -> [SubExp]
forall rep. BodyT rep -> [SubExp]
bodyResult (BodyT GPUMem -> [SubExp]) -> BodyT GPUMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda GPUMem
scan_op)) (((TypeBase Shape NoUniqueness, VName, SubExp)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((TypeBase Shape NoUniqueness, VName, SubExp)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
\(TypeBase Shape NoUniqueness
t, VName
arr, SubExp
se) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
arr [KernelConstants -> TypeBase Shape NoUniqueness -> TExp Int64
localArrayIndex KernelConstants
constants TypeBase Shape NoUniqueness
t] SubExp
se []
let crossesSegment' :: Maybe
(TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool)
crossesSegment' = do
TExp Int64 -> TExp Int64 -> TExp Bool
f <- CrossesSegment
crossesSegment
(TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool)
-> Maybe
(TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool)
forall a. a -> Maybe a
Just ((TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool)
-> Maybe
(TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool))
-> (TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool)
-> Maybe
(TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool)
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int32 ExpLeaf
from TPrimExp Int32 ExpLeaf
to ->
let from' :: TExp Int64
from' = TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 ExpLeaf
from TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset
to' :: TExp Int64
to' = TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 ExpLeaf
to TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset
in TExp Int64 -> TExp Int64 -> TExp Bool
f TExp Int64
from' TExp Int64
to'
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> ImpM GPUMem KernelEnv KernelOp ())
-> KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
fence
Lambda GPUMem
scan_op_renamed <- Lambda GPUMem -> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scan_op
Maybe
(TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool)
-> TExp Int64
-> TExp Int64
-> Lambda GPUMem
-> [VName]
-> ImpM GPUMem KernelEnv KernelOp ()
groupScan
Maybe
(TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool)
crossesSegment'
(TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TPrimExp Int32 ExpLeaf -> TExp Int64)
-> TPrimExp Int32 ExpLeaf -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TPrimExp Int32 ExpLeaf
forall t. TV t -> TExp t
tvExp TV Int32
num_threads)
(TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int64 -> TExp Int64) -> TExp Int64 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
Lambda GPUMem
scan_op_renamed
[VName]
local_arrs
SpaceId
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. SpaceId -> ImpM rep r op () -> ImpM rep r op ()
sComment SpaceId
"threads in bounds write partial scan result" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
TExp Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
in_bounds (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(TypeBase Shape NoUniqueness, PatElemT LParamMem, VName)]
-> ((TypeBase Shape NoUniqueness, PatElemT LParamMem, VName)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TypeBase Shape NoUniqueness]
-> [PatElemT LParamMem]
-> [VName]
-> [(TypeBase Shape NoUniqueness, PatElemT LParamMem, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TypeBase Shape NoUniqueness]
rets [PatElemT LParamMem]
pes [VName]
local_arrs) (((TypeBase Shape NoUniqueness, PatElemT LParamMem, VName)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((TypeBase Shape NoUniqueness, PatElemT LParamMem, VName)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(TypeBase Shape NoUniqueness
t, PatElemT LParamMem
pe, VName
arr) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe)
((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
gtids [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)
(VName -> SubExp
Var VName
arr)
[KernelConstants -> TypeBase Shape NoUniqueness -> TExp Int64
localArrayIndex KernelConstants
constants TypeBase Shape NoUniqueness
t]
ImpM GPUMem KernelEnv KernelOp ()
barrier
let load_carry :: ImpM GPUMem KernelEnv KernelOp ()
load_carry =
[(VName, Param LParamMem)]
-> ((VName, Param LParamMem) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [Param LParamMem] -> [(VName, Param LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
local_arrs [LParam GPUMem]
[Param LParamMem]
scan_x_params) (((VName, Param LParamMem) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, Param LParamMem) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, Param LParamMem
p) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p)
[]
(VName -> SubExp
Var VName
arr)
[ if TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param LParamMem
p
then TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1
else
(TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelGroupId KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1)
TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1
]
load_neutral :: ImpM GPUMem KernelEnv KernelOp ()
load_neutral =
[(SubExp, Param LParamMem)]
-> ((SubExp, Param LParamMem) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SubExp] -> [Param LParamMem] -> [(SubExp, Param LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
nes [LParam GPUMem]
[Param LParamMem]
scan_x_params) (((SubExp, Param LParamMem) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((SubExp, Param LParamMem) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, Param LParamMem
p) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
ne []
SpaceId
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. SpaceId -> ImpM rep r op () -> ImpM rep r op ()
sComment SpaceId
"first thread reads last element as carry-in for next iteration" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
TExp Bool
crosses_segment <- SpaceId -> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall t rep r op. SpaceId -> TExp t -> ImpM rep r op (TExp t)
dPrimVE SpaceId
"crosses_segment" (TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool))
-> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall a b. (a -> b) -> a -> b
$
case CrossesSegment
crossesSegment of
CrossesSegment
Nothing -> TExp Bool
forall v. TPrimExp Bool v
false
Just TExp Int64 -> TExp Int64 -> TExp Bool
f ->
TExp Int64 -> TExp Int64 -> TExp Bool
f
( TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset
TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
-TExp Int64
1
)
( TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
chunk_offset
TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
)
TExp Bool
should_load_carry <-
SpaceId -> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall t rep r op. SpaceId -> TExp t -> ImpM rep r op (TExp t)
dPrimVE SpaceId
"should_load_carry" (TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool))
-> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall a b. (a -> b) -> a -> b
$
KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants TPrimExp Int32 ExpLeaf -> TPrimExp Int32 ExpLeaf -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int32 ExpLeaf
0 TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TExp Bool
crosses_segment
TExp Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
should_load_carry ImpM GPUMem KernelEnv KernelOp ()
load_carry
Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan ImpM GPUMem KernelEnv KernelOp ()
barrier
TExp Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
should_load_carry ImpM GPUMem KernelEnv KernelOp ()
load_neutral
ImpM GPUMem KernelEnv KernelOp ()
barrier
(TV Int32, TExp Int64, CrossesSegment)
-> CallKernelGen (TV Int32, TExp Int64, CrossesSegment)
forall (m :: * -> *) a. Monad m => a -> m a
return (TV Int32
num_threads, TExp Int64
elems_per_group, CrossesSegment
crossesSegment)
-- | Generate the @scan_stage2@ kernel of the two-pass scan.
--
-- The kernel is launched as a single group whose size is the value wrapped
-- by @num_groups@ (presumably the number of groups used by the preceding
-- stage; confirm against the caller).  Local thread @i@ works on flat index
-- @(i + 1) * elems_per_group - 1@.  For every scan operator, and for every
-- index of that operator's vector shape, the generated code:
--
--   1. copies the pattern element at that index into the operator's local
--      array, or the neutral element if the index is outside the segment
--      space;
--   2. issues the barrier chosen by 'barrierFor' for the operator;
--   3. runs 'groupScan' over the local arrays;
--   4. for in-bounds threads, copies the scanned value back to the same
--      position of the pattern element.
scanStage2 ::
  Pattern GPUMem ->
  TV Int32 ->
  Imp.TExp Int64 ->
  Count NumGroups SubExp ->
  CrossesSegment ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  CallKernelGen ()
scanStage2 (Pattern _ all_pes) stage1_num_threads elems_per_group num_groups crossesSegment space scans =
  sKernelThread "scan_stage2" 1 carry_group_size' (segFlat space) $ do
    constants <- kernelConstants <$> askEnv
    local_arrs_per_scan <-
      makeLocalArrays carry_group_size (tvSize stage1_num_threads) scans

    -- Flat index handled by this thread, then split into per-dimension
    -- segment indices bound to the space's own index variables.
    let ltid = sExt64 (kernelLocalThreadId constants)
    flat_idx <- dPrimV "flat_idx" (lastElemOfChunk ltid)
    zipWithM_ dPrimV_ gtids (unflattenIndex dims' (tvExp flat_idx))

    let gtid_exps = map Imp.vi64 gtids
        -- True when every segment index is below its dimension.
        in_bounds = foldl1 (.&&.) [i .<. d | (i, d) <- zip gtid_exps dims']
        scan_inputs =
          zip4 scans local_arrs_per_scan rets_per_scan pes_per_scan

    forM_ scan_inputs $ \(SegBinOp _ scan_op nes vec_shape, local_arrs, rets, pes) ->
      sLoopNest vec_shape $ \vec_is -> do
        let glob_is = gtid_exps ++ vec_is
            localSlot t = [localArrayIndex constants t]
            (_, _, barrier) = barrierFor scan_op
            -- Global result -> local array.
            loadCarry t arr pe =
              copyDWIMFix arr (localSlot t) (Var (patElemName pe)) glob_is
            -- Neutral element -> local array.
            loadNeutral t arr ne =
              copyDWIMFix arr (localSlot t) ne []
            -- Local array -> global result.
            storeCarry t pe arr =
              copyDWIMFix (patElemName pe) glob_is (Var arr) (localSlot t)

        sComment "threads in bound read carries; others get neutral element" $
          sIf
            in_bounds
            (sequence_ (zipWith3 loadCarry rets local_arrs pes))
            (sequence_ (zipWith3 loadNeutral rets local_arrs nes))

        barrier

        groupScan
          crossesSegment'
          (sExt64 (tvExp stage1_num_threads))
          (sExt64 (kernelGroupSize constants))
          scan_op
          local_arrs

        sComment "threads in bounds write scanned carries" $
          sWhen in_bounds $
            sequence_ (zipWith3 storeCarry rets pes local_arrs)
  where
    (gtids, dims) = unzip (unSegSpace space)
    dims' = map toInt64Exp dims

    -- The group size of this kernel is the value carried by @num_groups@,
    -- merely re-tagged as a group size.
    carry_group_size = Count (unCount num_groups)
    carry_group_size' = fmap toInt64Exp carry_group_size

    per_scan_lambda = segBinOpLambda
    rets_per_scan = map (lambdaReturnType . per_scan_lambda) scans
    pes_per_scan = segBinOpChunks scans all_pes

    -- Flat index of the last element of chunk @i@, where a chunk holds
    -- @elems_per_group@ elements.
    lastElemOfChunk i = (i + 1) * elems_per_group - 1

    -- The segment-crossing predicate, if any, re-expressed over thread
    -- indices: both arguments are widened to 64 bits and mapped to the flat
    -- indices those threads handle before the original predicate is applied.
    crossesSegment' =
      fmap
        ( \crosses from to ->
            crosses
              (lastElemOfChunk (sExt64 from))
              (lastElemOfChunk (sExt64 to))
        )
        crossesSegment
scanStage3 ::
Pattern GPUMem ->
Count NumGroups SubExp ->
Count GroupSize SubExp ->
Imp.TExp Int64 ->
CrossesSegment ->
SegSpace ->
[SegBinOp GPUMem] ->
CallKernelGen ()
scanStage3 :: Pattern GPUMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> TExp Int64
-> CrossesSegment
-> SegSpace
-> [SegBinOp GPUMem]
-> CallKernelGen ()
scanStage3 (Pattern [PatElem GPUMem]
_ [PatElem GPUMem]
all_pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size TExp Int64
elems_per_group CrossesSegment
crossesSegment SegSpace
space [SegBinOp GPUMem]
scans = do
let num_groups' :: Count NumGroups (TExp Int64)
num_groups' = (SubExp -> TExp Int64)
-> Count NumGroups SubExp -> Count NumGroups (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count NumGroups SubExp
num_groups
group_size' :: Count GroupSize (TExp Int64)
group_size' = (SubExp -> TExp Int64)
-> Count GroupSize SubExp -> Count GroupSize (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count GroupSize SubExp
group_size
([VName]
gtids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
dims' :: [TExp Int64]
dims' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
dims
TPrimExp Int32 ExpLeaf
required_groups <-
SpaceId
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int32 ExpLeaf)
forall t rep r op. SpaceId -> TExp t -> ImpM rep r op (TExp t)
dPrimVE SpaceId
"required_groups" (TPrimExp Int32 ExpLeaf
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int32 ExpLeaf))
-> TPrimExp Int32 ExpLeaf
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int32 ExpLeaf)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TPrimExp Int32 ExpLeaf
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TPrimExp Int32 ExpLeaf)
-> TExp Int64 -> TPrimExp Int32 ExpLeaf
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
dims' TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size')
SpaceId
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> CallKernelGen ()
sKernelThread SpaceId
"scan_stage3" Count NumGroups (TExp Int64)
num_groups' Count GroupSize (TExp Int64)
group_size' (SegSpace -> VName
segFlat SegSpace
space) (ImpM GPUMem KernelEnv KernelOp () -> CallKernelGen ())
-> ImpM GPUMem KernelEnv KernelOp () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
SegVirt
-> TPrimExp Int32 ExpLeaf
-> (TPrimExp Int32 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
virtualiseGroups SegVirt
SegVirt TPrimExp Int32 ExpLeaf
required_groups ((TPrimExp Int32 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (TPrimExp Int32 ExpLeaf -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int32 ExpLeaf
virt_group_id -> do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
TExp Int64
flat_idx <-
SpaceId
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall t rep r op. SpaceId -> TExp t -> ImpM rep r op (TExp t)
dPrimVE SpaceId
"flat_idx" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 ExpLeaf
virt_group_id TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size')
TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TPrimExp Int32 ExpLeaf -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 ExpLeaf
kernelLocalThreadId KernelConstants
constants)
(VName -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> [VName] -> [TExp Int64] -> ImpM GPUMem KernelEnv KernelOp ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
gtids ([TExp Int64] -> ImpM GPUMem KernelEnv KernelOp ())
-> [TExp Int64] -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
dims' TExp Int64
flat_idx
TV Int64
orig_group <- SpaceId -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall t rep r op. SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"orig_group" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
flat_idx TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64
elems_per_group
TV Int64
carry_in_flat_idx <-
SpaceId -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall t rep r op. SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"carry_in_flat_idx" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
orig_group TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
elems_per_group TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1
let carry_in_idx :: [TExp Int64]
carry_in_idx = [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
dims' (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
carry_in_flat_idx
let in_bounds :: TExp Bool
in_bounds =
(TExp Bool -> TExp Bool -> TExp Bool) -> [TExp Bool] -> TExp Bool
forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldl1 TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
(.&&.) ([TExp Bool] -> TExp Bool) -> [TExp Bool] -> TExp Bool
forall a b. (a -> b) -> a -> b
$ (TExp Int64 -> TExp Int64 -> TExp Bool)
-> [TExp Int64] -> [TExp Int64] -> [TExp Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
(.<.) ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
gtids) [TExp Int64]
dims'
crosses_segment :: TExp Bool
crosses_segment =
TExp Bool -> Maybe (TExp Bool) -> TExp Bool
forall a. a -> Maybe a -> a
fromMaybe TExp Bool
forall v. TPrimExp Bool v
false (Maybe (TExp Bool) -> TExp Bool) -> Maybe (TExp Bool) -> TExp Bool
forall a b. (a -> b) -> a -> b
$
CrossesSegment
crossesSegment
CrossesSegment
-> Maybe (TExp Int64) -> Maybe (TExp Int64 -> TExp Bool)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> TExp Int64 -> Maybe (TExp Int64)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
carry_in_flat_idx)
Maybe (TExp Int64 -> TExp Bool)
-> Maybe (TExp Int64) -> Maybe (TExp Bool)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> TExp Int64 -> Maybe (TExp Int64)
forall (f :: * -> *) a. Applicative f => a -> f a
pure TExp Int64
flat_idx
is_a_carry :: TExp Bool
is_a_carry = TExp Int64
flat_idx TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
orig_group TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
elems_per_group TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1
no_carry_in :: TExp Bool
no_carry_in = TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
orig_group TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0 TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TExp Bool
is_a_carry TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TExp Bool
crosses_segment
let per_scan_pes :: [[PatElemT LParamMem]]
per_scan_pes = [SegBinOp GPUMem] -> [PatElemT LParamMem] -> [[PatElemT LParamMem]]
forall rep a. [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks [SegBinOp GPUMem]
scans [PatElem GPUMem]
[PatElemT LParamMem]
all_pes
TExp Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
in_bounds (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
TExp Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
no_carry_in (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[([PatElemT LParamMem], SegBinOp GPUMem)]
-> (([PatElemT LParamMem], SegBinOp GPUMem)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([[PatElemT LParamMem]]
-> [SegBinOp GPUMem] -> [([PatElemT LParamMem], SegBinOp GPUMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [[PatElemT LParamMem]]
per_scan_pes [SegBinOp GPUMem]
scans) ((([PatElemT LParamMem], SegBinOp GPUMem)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> (([PatElemT LParamMem], SegBinOp GPUMem)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
\([PatElemT LParamMem]
pes, SegBinOp Commutativity
_ Lambda GPUMem
scan_op [SubExp]
nes Shape
vec_shape) -> do
Maybe (Exp GPUMem)
-> Scope GPUMem -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Mem rep =>
Maybe (Exp rep) -> Scope rep -> ImpM rep r op ()
dScope Maybe (Exp GPUMem)
forall a. Maybe a
Nothing (Scope GPUMem -> ImpM GPUMem KernelEnv KernelOp ())
-> Scope GPUMem -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ [Param LParamMem] -> Scope GPUMem
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams ([Param LParamMem] -> Scope GPUMem)
-> [Param LParamMem] -> Scope GPUMem
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda GPUMem
scan_op
let ([Param LParamMem]
scan_x_params, [Param LParamMem]
scan_y_params) =
Int -> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([Param LParamMem] -> ([Param LParamMem], [Param LParamMem]))
-> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda GPUMem
scan_op
Shape
-> ([TExp Int64] -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
vec_shape (([TExp Int64] -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ([TExp Int64] -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is -> do
[(Param LParamMem, PatElemT LParamMem)]
-> ((Param LParamMem, PatElemT LParamMem)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [PatElemT LParamMem] -> [(Param LParamMem, PatElemT LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
scan_x_params [PatElemT LParamMem]
pes) (((Param LParamMem, PatElemT LParamMem)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((Param LParamMem, PatElemT LParamMem)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, PatElemT LParamMem
pe) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p)
[]
(VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe)
([TExp Int64]
carry_in_idx [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)
[(Param LParamMem, PatElemT LParamMem)]
-> ((Param LParamMem, PatElemT LParamMem)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [PatElemT LParamMem] -> [(Param LParamMem, PatElemT LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
scan_y_params [PatElemT LParamMem]
pes) (((Param LParamMem, PatElemT LParamMem)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((Param LParamMem, PatElemT LParamMem)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, PatElemT LParamMem
pe) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p)
[]
(VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe)
((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
gtids [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)
[Param LParamMem]
-> BodyT GPUMem -> ImpM GPUMem KernelEnv KernelOp ()
forall dec rep r op. [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [Param LParamMem]
scan_x_params (BodyT GPUMem -> ImpM GPUMem KernelEnv KernelOp ())
-> BodyT GPUMem -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> BodyT GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda GPUMem
scan_op
[(Param LParamMem, PatElemT LParamMem)]
-> ((Param LParamMem, PatElemT LParamMem)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [PatElemT LParamMem] -> [(Param LParamMem, PatElemT LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
scan_x_params [PatElemT LParamMem]
pes) (((Param LParamMem, PatElemT LParamMem)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((Param LParamMem, PatElemT LParamMem)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, PatElemT LParamMem
pe) ->
VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe)
((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
gtids [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)
(VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p)
[]
-- | Host-level ('CallKernelGen') entry point that generates code for a
-- segmented scan by running three stages in order: 'scanStage1',
-- 'scanStage2' and 'scanStage3'.
--
-- Arguments, as used in the body below:
--
-- * @pat@: forwarded unchanged to all three stages.
-- * @lvl@: only consulted through 'segNumGroups' and 'segGroupSize'.
-- * @space@ and @scans@: forwarded unchanged to all three stages.
-- * @kbody@: passed to 'scanStage1' only.
--
-- NOTE(review): in this text the signature appears twice (the declared
-- form, then a type-annotated repetition) and every identifier in the body
-- is preceded by its inferred type; these lines are left untouched.
compileSegScan ::
Pattern GPUMem ->
SegLevel ->
SegSpace ->
[SegBinOp GPUMem] ->
KernelBody GPUMem ->
CallKernelGen ()
compileSegScan :: Pattern GPUMem
-> SegLevel
-> SegSpace
-> [SegBinOp GPUMem]
-> KernelBody GPUMem
-> CallKernelGen ()
compileSegScan Pattern GPUMem
pat SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
scans KernelBody GPUMem
kbody = do
-- Declare an int64 variable named "stage1_max_num_groups"; it is given its
-- value by the GetSizeMax host operation emitted right after.
TV Int64
stage1_max_num_groups <- SpaceId -> PrimType -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall rep r op t. SpaceId -> PrimType -> ImpM rep r op (TV t)
dPrim SpaceId
"stage1_max_num_groups" PrimType
int64
-- Query the maximum for the SizeGroup size class (not a num-groups class)
-- into that variable.
-- NOTE(review): presumably the stage-1 group count is later used as a
-- group size by stage 2, hence the group-size limit - confirm against
-- scanStage2.
HostOp -> CallKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
stage1_max_num_groups) SizeClass
SizeGroup
-- stage1_num_groups = min(stage1_max_num_groups, segNumGroups lvl), bound
-- to a fresh variable "stage1_num_groups" and wrapped as a
-- Count NumGroups SubExp.
Count NumGroups SubExp
stage1_num_groups <-
(TV Int64 -> Count NumGroups SubExp)
-> ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumGroups SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Imp.Count (SubExp -> Count NumGroups SubExp)
-> (TV Int64 -> SubExp) -> TV Int64 -> Count NumGroups SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize) (ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumGroups SubExp))
-> ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumGroups SubExp)
forall a b. (a -> b) -> a -> b
$
SpaceId -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"stage1_num_groups" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
stage1_max_num_groups) (TExp Int64 -> TExp Int64) -> TExp Int64 -> TExp Int64
forall a b. (a -> b) -> a -> b
$
SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64) -> SubExp -> TExp Int64
forall a b. (a -> b) -> a -> b
$ Count NumGroups SubExp -> SubExp
forall u e. Count u e -> e
Imp.unCount (Count NumGroups SubExp -> SubExp)
-> Count NumGroups SubExp -> SubExp
forall a b. (a -> b) -> a -> b
$ SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl
-- Stage 1: run with the capped group count and the group size from lvl.
-- Its three results (thread count, elements per group, and the
-- crossesSegment value) are handed on to the later stages.
(TV Int32
stage1_num_threads, TExp Int64
elems_per_group, CrossesSegment
crossesSegment) <-
Pattern GPUMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegBinOp GPUMem]
-> KernelBody GPUMem
-> CallKernelGen (TV Int32, TExp Int64, CrossesSegment)
scanStage1 Pattern GPUMem
pat Count NumGroups SubExp
stage1_num_groups (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl) SegSpace
space [SegBinOp GPUMem]
scans KernelBody GPUMem
kbody
-- Emit a DebugPrint statement labelled "elems_per_group" carrying the
-- untyped elems_per_group expression.
Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ SpaceId -> Maybe Exp -> Code HostOp
forall a. SpaceId -> Maybe Exp -> Code a
Imp.DebugPrint SpaceId
"elems_per_group" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
elems_per_group
-- Stage 2: receives the stage-1 thread count and the capped
-- stage1_num_groups; it takes no kernel body and no group size.
Pattern GPUMem
-> TV Int32
-> TExp Int64
-> Count NumGroups SubExp
-> CrossesSegment
-> SegSpace
-> [SegBinOp GPUMem]
-> CallKernelGen ()
scanStage2 Pattern GPUMem
pat TV Int32
stage1_num_threads TExp Int64
elems_per_group Count NumGroups SubExp
stage1_num_groups CrossesSegment
crossesSegment SegSpace
space [SegBinOp GPUMem]
scans
-- Stage 3: note that this is given the original segNumGroups lvl and
-- segGroupSize lvl, not the capped stage1_num_groups used by stages 1-2.
Pattern GPUMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> TExp Int64
-> CrossesSegment
-> SegSpace
-> [SegBinOp GPUMem]
-> CallKernelGen ()
scanStage3 Pattern GPUMem
pat (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl) TExp Int64
elems_per_group CrossesSegment
crossesSegment SegSpace
space [SegBinOp GPUMem]
scans