{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.GPU.SegRed
( compileSegRed,
compileSegRed',
DoSegBody,
)
where
import Control.Monad
import Data.List (genericLength, zip4)
import Data.Map qualified as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Transform.Rename
import Futhark.Util (chunks, mapAccumLM)
import Futhark.Util.IntegralExp (divUp, nextMul, quot, rem)
import Prelude hiding (quot, rem)
-- | Run a two-argument monadic action over two lists in lockstep,
-- discarding the results.  The lists are paired with 'zip', so any
-- surplus elements of the longer list are ignored.
forM2_ :: (Monad m) => [a] -> [b] -> (a -> b -> m c) -> m ()
forM2_ xs ys f = forM_ (zip xs ys) (uncurry f)
-- | The largest number of reduction operators accepted in a single
-- SegRed.  'compileSegRed'' reports a compiler limitation when given
-- more operators than this, and 'nonsegmentedReduction' passes this
-- value to @genZeroes@ when creating its @counters@ array.
maxNumOps :: Int
maxNumOps = 10
-- | Code generation for the body of a SegRed, in continuation-passing
-- style.  The action is handed a continuation that receives the
-- results of the body; each result is a t'SubExp' paired with a list
-- of 'Int64' expressions (presumably indexes into that t'SubExp' for
-- reading the result -- 'compileSegRed' always passes an empty list).
type DoSegBody = ([(SubExp, [Imp.TExp Int64])] -> InKernelGen ()) -> InKernelGen ()
-- | The intermediate arrays created inside a kernel for one reduction
-- operator.  Which constructor is used is decided by
-- 'makeIntermArrays'.
data SegRedIntermediateArrays
  = -- | Produced by 'generalSegRedInterms': one array per operator
    -- parameter.
    GeneralSegRedInterms
      { groupRedArrs :: [VName]
      }
  | -- | Produced by 'makeIntermArrays' when the combined commutativity
    -- of the operators is 'Noncommutative' and every operator is
    -- primitive (see 'isPrimSegBinOp').
    NoncommPrimSegRedInterms
      { -- | The @coll_copy_arr@ arrays, each of length
        -- @group_size * chunk@ and placed at offset 0 of the
        -- @local_mem@ pool.
        collCopyArrs :: [VName],
        -- | The @group_red_arr@ arrays, each of length @group_size@,
        -- placed at distinct offsets in the same @local_mem@ pool.
        groupRedArrs :: [VName],
        -- | The per-parameter @chunk@ arrays, allocated in a
        -- 'ScalarSpace' of size @chunk@.
        privateChunks :: [VName]
      }
-- | Compile a SegRed whose mapped part is given as a 'KernelBody'.
--
-- The number of groups and the group size are obtained from the
-- 'SegLevel' via 'lvlKernelAttrs'; the actual work is delegated to
-- 'compileSegRed''.  The body passed on compiles the statements of
-- the kernel body, writes any map-out results (those results and
-- pattern elements that come after the first @segBinOpResults
-- segbinops@ ones) with 'compileThreadResult', and finally hands the
-- reduction results to the continuation with empty index lists.
--
-- A debug print is emitted before and after the generated code.
compileSegRed ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegRed pat lvl space segbinops map_kbody = do
  emit $ Imp.DebugPrint "\n# SegRed" Nothing
  KernelAttrs {kAttrNumGroups = num_groups, kAttrGroupSize = group_size} <-
    lvlKernelAttrs lvl
  let grid = KernelGrid num_groups group_size

  compileSegRed' pat grid space segbinops $ \red_cont ->
    sComment "apply map function" $
      compileStms mempty (kernelBodyStms map_kbody) $ do
        -- The leading results belong to the reduction operators; the
        -- remainder are plain map outputs.
        let (red_res, map_res) =
              splitAt (segBinOpResults segbinops) $ kernelBodyResult map_kbody

        let mapout_arrs = drop (segBinOpResults segbinops) $ patElems pat
        unless (null mapout_arrs) $
          sComment "write map-out result(s)" $ do
            zipWithM_ (compileThreadResult space) mapout_arrs map_res

        red_cont $ map ((,[]) . kernelResultSubExp) red_res
  emit $ Imp.DebugPrint "" Nothing
-- | The leading parameters of the operator's lambda, as many as the
-- operator has neutral elements.
paramOf :: SegBinOp GPUMem -> [Param LParamMem]
paramOf (SegBinOp _ op ne _) = take (length ne) $ lambdaParams op
-- | Does the operator work purely on primitives?  True when every
-- return type of the operator's lambda is a primitive type and the
-- operator's shape has rank zero.
isPrimSegBinOp :: SegBinOp GPUMem -> Bool
isPrimSegBinOp segbinop =
  all primType (lambdaReturnType $ segBinOpLambda segbinop)
    && shapeRank (segBinOpShape segbinop) == 0
-- | Like 'compileSegRed', but where the body is given as a
-- 'DoSegBody' action rather than a 'KernelBody'.
--
-- Reports a compiler limitation if more than 'maxNumOps' operators
-- are given.  Otherwise a @chunk_size@ variable is declared and one
-- of three code generators is selected:
--
-- * If the segment space has exactly two dimensions and the first is
--   the constant 1, 'nonsegmentedReduction' is used.
--
-- * Otherwise a conditional is emitted (with 'sIf') that picks
--   'smallSegmentsReduction' when twice the innermost dimension is
--   less than the group size times the chunk size, and
--   'largeSegmentsReduction' if not.
compileSegRed' ::
  Pat LetDecMem ->
  KernelGrid ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  DoSegBody ->
  CallKernelGen ()
compileSegRed' pat grid space segbinops map_body_cont
  | genericLength segbinops > maxNumOps =
      compilerLimitationS $
        "compileSegRed': at most " ++ show maxNumOps ++ " reduction operators are supported."
  | otherwise = do
      chunk_v <- dPrimV "chunk_size" . isInt64 =<< kernelConstToExp chunk_const
      case unSegSpace space of
        [(_, Constant (IntValue (Int64Value 1))), _] ->
          compileReduction (chunk_v, chunk_const) nonsegmentedReduction
        _ -> do
          let segment_size = pe64 $ last $ segSpaceDims space
              use_small_segments =
                segment_size * 2 .<. pe64 (unCount group_size) * tvExp chunk_v
          sIf
            use_small_segments
            (compileReduction (chunk_v, chunk_const) smallSegmentsReduction)
            (compileReduction (chunk_v, chunk_const) largeSegmentsReduction)
  where
    -- Invoke the chosen code generator with the arguments that are
    -- common to all three strategies.
    compileReduction chunk f =
      f pat num_groups group_size chunk space segbinops map_body_cont

    param_types = map paramType $ concatMap paramOf segbinops

    num_groups = gridNumGroups grid
    group_size = gridGroupSize grid

    -- A chunk size is only computed (by 'getChunkSize') when at least
    -- one operator is noncommutative and all operators are
    -- primitive; in every other case the chunk size is the constant 1.
    chunk_const =
      if Noncommutative `elem` map segBinOpComm segbinops
        && all isPrimSegBinOp segbinops
        then getChunkSize param_types
        else Imp.ValueExp $ IntValue $ intValue Int64 (1 :: Int64)
-- | Prepare the intermediate arrays used inside a reduction kernel,
-- one 'SegRedIntermediateArrays' per operator.
--
-- If the combined commutativity of the operators is 'Noncommutative'
-- and every operator is primitive ('isPrimSegBinOp'), a single
-- @local_mem@ allocation in @Space "local"@ is made and carved up:
--
-- * every @coll_copy_arr@ has @group_size * chunk@ elements and
--   starts at offset 0 of the pool;
--
-- * every @group_red_arr@ has @group_size@ elements and gets its own
--   offset in the pool;
--
-- * every private @chunk@ array has @chunk@ elements and lives in a
--   'ScalarSpace'.
--
-- The pool is sized as the maximum of what the @coll_copy_arr@ arrays
-- and the @group_red_arr@ arrays need, since both are laid out from
-- the start of the same pool.
--
-- In all other cases the work is delegated to 'generalSegRedInterms'.
makeIntermArrays ::
  Imp.TExp Int64 ->
  SubExp ->
  SubExp ->
  [SegBinOp GPUMem] ->
  InKernelGen [SegRedIntermediateArrays]
makeIntermArrays group_id group_size chunk segbinops
  | Noncommutative <- mconcat (map segBinOpComm segbinops),
    all isPrimSegBinOp segbinops =
      noncommPrimSegRedInterms
  | otherwise =
      generalSegRedInterms group_id group_size segbinops
  where
    params = map paramOf segbinops

    noncommPrimSegRedInterms = do
      group_worksize <- tvSize <$> dPrimV "group_worksize" group_worksize_E

      -- Compute the total amount of local memory.  Given the bytes
      -- used so far (x) and an element size (y), 'sum_' rounds x up
      -- to a multiple of y and then adds room for group_size
      -- elements.
      let sum_ x y = nextMul x y + group_size_E * y
          group_reds_lmem_requirement = foldl sum_ 0 $ concat elem_sizes
          collcopy_lmem_requirement = group_worksize_E * max_elem_size
          lmem_total_size =
            Imp.bytes $
              collcopy_lmem_requirement `sMax64` group_reds_lmem_requirement

      -- Offsets (in elements) into the pool for each group reduction
      -- array.  The size computation above places each array at the
      -- byte offset rounded *up* to a multiple of its element size,
      -- so the element offset must be rounded up as well ('divUp');
      -- rounding down would let the array overlap the previous one
      -- whenever the running byte offset is not already aligned.
      (_, offsets) <-
        forAccumLM2D 0 elem_sizes $ \byte_offs elem_size ->
          (,byte_offs `divUp` elem_size)
            <$> dPrimVE "offset" (sum_ byte_offs elem_size)

      -- The total pool of local memory.
      lmem <- sAlloc "local_mem" lmem_total_size (Space "local")
      let arrInLMem ptype name len_se offset =
            sArray
              (name ++ "_" ++ prettyString ptype)
              ptype
              (Shape [len_se])
              lmem
              $ LMAD.iota offset [pe64 len_se]

      forM (zipWith zip params offsets) $ \ps_and_offsets -> do
        (coll_copy_arrs, group_red_arrs, priv_chunks) <-
          fmap unzip3 $ forM ps_and_offsets $ \(p, offset) -> do
            let ptype = elemType $ paramType p
            (,,)
              <$> arrInLMem ptype "coll_copy_arr" group_worksize 0
              <*> arrInLMem ptype "group_red_arr" group_size offset
              <*> sAllocArray
                ("chunk_" ++ prettyString ptype)
                ptype
                (Shape [chunk])
                (ScalarSpace [chunk] ptype)
        pure $ NoncommPrimSegRedInterms coll_copy_arrs group_red_arrs priv_chunks

    group_size_E = pe64 group_size
    group_worksize_E = group_size_E * pe64 chunk

    paramSize = primByteSize . elemType . paramType
    elem_sizes = map (map paramSize) params
    max_elem_size = maximum $ concat elem_sizes

    -- A left-to-right accumulating map over a two-level structure.
    forAccumLM2D acc ls f = mapAccumLM (mapAccumLM f) acc ls
-- | Create the intermediate arrays for the general case, one list of
-- arrays per operator and one array per operator parameter.
--
-- A parameter whose declaration is a 'MemArray' gets an array in the
-- memory block named by that declaration, with the group size
-- prepended to the parameter's shape and an offset of @group_id@
-- times the number of elements in that extended shape.  Any other
-- parameter gets a freshly allocated array of @group_size@ elements
-- in @Space "local"@.
generalSegRedInterms ::
  Imp.TExp Int64 ->
  SubExp ->
  [SegBinOp GPUMem] ->
  InKernelGen [SegRedIntermediateArrays]
generalSegRedInterms group_id group_size segbinops =
  fmap (map GeneralSegRedInterms) $
    forM (map paramOf segbinops) $
      mapM $ \p ->
        case paramDec p of
          MemArray pt shape _ (ArrayIn mem _) -> do
            let shape' = Shape [group_size] <> shape
            let shape_E = map pe64 $ shapeDims shape'
            sArray ("red_arr_" ++ prettyString pt) pt shape' mem $
              LMAD.iota (group_id * product shape_E) shape_E
          _ -> do
            let pt = elemType $ paramType p
                shape = Shape [group_size]
            sAllocArray ("red_arr_" ++ prettyString pt) pt shape $ Space "local"
-- | Allocate the arrays used for storing per-group results: one list
-- per operator, with one @segred_tmp@ array in @Space "device"@ per
-- return type of the operator's lambda.
--
-- Each array has the shape @[extra_dim, num_virtgroups]@ followed by
-- the operator's shape and the shape of the return type.  The extra
-- leading dimension has size 1 when the return type is primitive and
-- the operator's shape has rank zero, and size @group_size@
-- otherwise.
groupResultArrays ::
  SubExp ->
  SubExp ->
  [SegBinOp GPUMem] ->
  CallKernelGen [[VName]]
groupResultArrays num_virtgroups group_size segbinops =
  forM segbinops $ \(SegBinOp _ lam _ shape) ->
    forM (lambdaReturnType lam) $ \t -> do
      let pt = elemType t
          extra_dim
            | primType t, shapeRank shape == 0 = intConst Int64 1
            | otherwise = group_size
          full_shape = Shape [extra_dim, num_virtgroups] <> shape <> arrayShape t
          -- The permutation lists dimension 0 (the extra dimension)
          -- last; presumably this makes it the innermost dimension in
          -- memory -- confirm against 'sAllocArrayPerm'.
          perm = [1 .. shapeRank full_shape - 1] ++ [0]
      sAllocArrayPerm "segred_tmp" pt full_shape (Space "device") perm
-- | The common type of the reduction code generators that
-- 'compileSegRed'' chooses between.  The arguments are, in order: the
-- result pattern, the number of groups, the group size, the chunk
-- size (as a variable together with the kernel constant expression
-- that defines it), the segment space, the reduction operators, and
-- the code generator for the body.
type DoCompileSegRed =
  Pat LetDecMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  (TV Int64, Imp.KernelConstExp) ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  DoSegBody ->
  CallKernelGen ()
-- | Code generation for a reduction that 'compileSegRed'' has
-- identified as nonsegmented.
--
-- On the host side this creates a zeroed @counters@ array with
-- 'maxNumOps' entries, the per-group result arrays
-- ('groupResultArrays'), and a @num_threads@ variable.  It then
-- emits a single @segred_nonseg@ kernel in which the chunk-size
-- variable is bound as a kernel constant.  Inside the kernel the
-- intermediate arrays are set up with 'makeIntermArrays', all segment
-- indexes are bound to zero, 'reductionStageOne' is run once for all
-- operators, and 'reductionStageTwo' is run once per operator, each
-- operator being given its own index as the last argument.
nonsegmentedReduction :: DoCompileSegRed
nonsegmentedReduction (Pat segred_pes) num_groups group_size (chunk_v, chunk_const) space segbinops map_body_cont = do
  let (gtids, dims) = unzip $ unSegSpace space
      chunk = tvExp chunk_v
      num_groups_se = unCount num_groups
      group_size_se = unCount group_size
      group_size' = pe64 group_size_se
      global_tid = Imp.le64 $ segFlat space
      -- The number of elements to reduce is the innermost dimension.
      n = pe64 $ last dims

  counters <- genZeroes "counters" maxNumOps

  reds_group_res_arrs <- groupResultArrays num_groups_se group_size_se segbinops

  num_threads <-
    fmap tvSize $ dPrimV "num_threads" $ pe64 num_groups_se * group_size'

  let attrs =
        (defKernelAttrs num_groups group_size)
          { kAttrConstExps = M.singleton (tvVar chunk_v) chunk_const
          }

  sKernelThread "segred_nonseg" (segFlat space) attrs $ do
    constants <- kernelConstants <$> askEnv
    let ltid = kernelLocalThreadId constants
    let group_id = kernelGroupId constants

    interms <- makeIntermArrays (sExt64 group_id) group_size_se (tvSize chunk_v) segbinops
    sync_arr <- sAllocArray "sync_arr" Bool (Shape [intConst Int32 1]) $ Space "local"

    -- Every index variable of the segment space is bound to zero in
    -- this code path.
    forM_ gtids $ \v -> dPrimV_ v (0 :: Imp.TExp Int64)

    -- n divided (rounding up) by the number of threads times the
    -- chunk size.
    q <- dPrimVE "q" $ n `divUp` (sExt64 (kernelNumThreads constants) * chunk)

    slugs <-
      mapM (segBinOpSlug ltid group_id) $
        zip3 segbinops interms reds_group_res_arrs
    new_lambdas <-
      reductionStageOne
        gtids
        n
        global_tid
        q
        chunk
        (pe64 num_threads)
        slugs
        map_body_cont

    -- Split the pattern elements into one group per operator, sized
    -- by the operator's number of neutral elements.
    let segred_pess =
          chunks
            (map (length . segBinOpNeutral) segbinops)
            segred_pes

    forM_ (zip4 segred_pess slugs new_lambdas [0 ..]) $
      \(pes, slug, new_lambda, i) ->
        reductionStageTwo
          pes
          group_id
          [0]
          0
          (sExt64 $ kernelNumGroups constants)
          slug
          new_lambda
          counters
          sync_arr
          (fromInteger i)
-- | Generate the @segred_small@ kernel: a segmented reduction in which
-- each workgroup handles one or more whole segments.
--
-- The innermost dimension of the 'SegSpace' is the segment size; the
-- product of the remaining dimensions is the number of segments.  Every
-- thread writes either its mapped value or the neutral element into the
-- per-group intermediate arrays, after which a segmented 'groupScan' is
-- run and the last element of each segment is copied to the pattern.
--
-- The chunk argument (fourth parameter) is ignored by this strategy.
smallSegmentsReduction :: DoCompileSegRed
smallSegmentsReduction (Pat segred_pes) num_groups group_size _ space segbinops map_body_cont = do
  let (gtids, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims
      segment_size = last dims'

  -- Clamped to at least 1 so that it can be used as a divisor even when
  -- the segments are empty.
  segment_size_nonzero <-
    dPrimVE "segment_size_nonzero" $ sMax64 1 segment_size

  let group_size_se = unCount group_size
      -- The number of groups comes from 'num_groups', not 'group_size';
      -- 'num_threads' below is their product.
      num_groups_se = unCount num_groups
      num_groups' = pe64 num_groups_se
      group_size' = pe64 group_size_se
  num_threads <- fmap tvSize $ dPrimV "num_threads" $ num_groups' * group_size'
  let num_segments = product $ init dims'
      segments_per_group = group_size' `quot` segment_size_nonzero
      required_groups = sExt32 $ num_segments `divUp` segments_per_group

  emit $ Imp.DebugPrint "# SegRed-small" Nothing
  emit $ Imp.DebugPrint "num_segments" $ Just $ untyped num_segments
  emit $ Imp.DebugPrint "segment_size" $ Just $ untyped segment_size
  emit $ Imp.DebugPrint "segments_per_group" $ Just $ untyped segments_per_group
  emit $ Imp.DebugPrint "required_groups" $ Just $ untyped required_groups

  sKernelThread "segred_small" (segFlat space) (defKernelAttrs num_groups group_size) $ do
    constants <- kernelConstants <$> askEnv
    -- The physical group id, widened to 64 bits (same expression as in
    -- 'largeSegmentsReduction').
    -- NOTE(review): presumably 'generalSegRedInterms' uses this to give
    -- each group its own slice of memory-backed intermediate arrays --
    -- confirm against its definition.
    let group_id = sExt64 $ kernelGroupId constants
        ltid = sExt64 $ kernelLocalThreadId constants

    interms <- generalSegRedInterms group_id group_size_se segbinops
    let reds_arrs = map groupRedArrs interms

    -- 'required_groups' may differ from the number of launched groups,
    -- so the body is run once per virtual group id.
    virtualiseGroups SegVirt required_groups $ \virtgroup_id -> do
      -- The outer indices identify the segment and are derived from the
      -- virtual group id; the innermost index comes from the local
      -- thread id and is only meaningful when the guard below holds.
      let segment_index =
            (ltid `quot` segment_size_nonzero)
              + (sExt64 virtgroup_id * sExt64 segments_per_group)
          index_within_segment = ltid `rem` segment_size

      dIndexSpace (zip (init gtids) (init dims')) segment_index
      dPrimV_ (last gtids) index_within_segment

      let -- Run the map body and store its results at this thread's
          -- slot of the intermediate arrays.
          in_bounds =
            map_body_cont $ \red_res ->
              sComment "save results to be reduced" $ do
                let red_dests = map (,[ltid]) (concat reds_arrs)
                forM2_ red_dests red_res $ \(d, d_is) (res, res_is) ->
                  copyDWIMFix d d_is res res_is

          -- Threads without an element contribute the neutral element.
          out_of_bounds =
            forM2_ segbinops reds_arrs $ \(SegBinOp _ _ nes _) red_arrs ->
              forM2_ red_arrs nes $ \arr ne ->
                copyDWIMFix arr [ltid] ne []

      sComment "apply map function if in bounds" $
        sIf
          ( (segment_size .>. 0)
              .&&. isActive (init $ zip gtids dims)
              .&&. (ltid .<. segment_size * segments_per_group)
          )
          in_bounds
          out_of_bounds

      sOp $ Imp.ErrorSync Imp.FenceLocal

      -- Flag passed to 'groupScan': true when combining positions
      -- @from@ and @to@ would span a segment boundary.
      let crossesSegment from to =
            (sExt64 to - sExt64 from) .>. (sExt64 to `rem` segment_size)

      sWhen (segment_size .>. 0) $
        sComment "perform segmented scan to imitate reduction" $
          forM2_ segbinops reds_arrs $ \(SegBinOp _ red_op _ _) red_arrs ->
            groupScan
              (Just crossesSegment)
              (sExt64 $ pe64 num_threads)
              (segment_size * segments_per_group)
              red_op
              red_arrs

      sOp $ Imp.Barrier Imp.FenceLocal

      sComment "save final values of segments"
        $ sWhen
          ( ( sExt64 virtgroup_id * segments_per_group + sExt64 ltid
                .<. num_segments
            )
              .&&. (ltid .<. segments_per_group)
          )
        $ forM2_ segred_pes (concat reds_arrs)
        $ \pe arr -> do
          -- Thread @ltid@ writes the result of the @ltid@'th segment of
          -- this virtual group, which after the scan sits at the last
          -- position of that segment.
          let flat_segment_index =
                sExt64 virtgroup_id * segments_per_group + sExt64 ltid
              gtids' =
                unflattenIndex (init dims') flat_segment_index
          copyDWIMFix
            (patElemName pe)
            gtids'
            (Var arr)
            [(ltid + 1) * segment_size_nonzero - 1]

      -- Barrier before the next virtual-group iteration starts writing
      -- to the intermediate arrays again.
      sOp $ Imp.Barrier Imp.FenceLocal
-- | Generate the @segred_large@ kernel: a segmented reduction in which
-- one or more workgroups cooperate on each segment.
--
-- Stage one ('reductionStageOne') produces one partial result per
-- virtual group.  When a segment is handled by a single group, the first
-- thread of that group writes the accumulator straight to the pattern;
-- otherwise 'reductionStageTwo' combines the per-group results using the
-- @counters@ and @sync_arr@ arrays.
--
-- The chunk size is supplied as a variable together with a kernel
-- constant expression, which is registered in the kernel attributes.
largeSegmentsReduction :: DoCompileSegRed
largeSegmentsReduction (Pat segred_pes) num_groups group_size (chunk_v, chunk_const) space segbinops map_body_cont = do
  let (gtids, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims
      num_segments = product $ init dims'
      segment_size = last dims'
      num_groups' = pe64 $ unCount num_groups
      group_size_se = unCount group_size
      group_size' = pe64 group_size_se
      chunk = tvExp chunk_v

  -- 'sMax64 1' guards the division against zero segments.
  groups_per_segment <-
    dPrimVE "groups_per_segment" $
      num_groups' `divUp` sMax64 1 num_segments

  q <-
    dPrimVE "q" $
      segment_size `divUp` (group_size' * groups_per_segment * chunk)

  num_virtgroups <-
    dPrimV "num_virtgroups" $
      groups_per_segment * num_segments
  threads_per_segment <-
    dPrimVE "threads_per_segment" $
      groups_per_segment * group_size'

  emit $ Imp.DebugPrint "# SegRed-large" Nothing
  emit $ Imp.DebugPrint "num_segments" $ Just $ untyped num_segments
  emit $ Imp.DebugPrint "segment_size" $ Just $ untyped segment_size
  emit $ Imp.DebugPrint "num_virtgroups" $ Just $ untyped $ tvExp num_virtgroups
  emit $ Imp.DebugPrint "num_groups" $ Just $ untyped num_groups'
  emit $ Imp.DebugPrint "group_size" $ Just $ untyped group_size'
  emit $ Imp.DebugPrint "q" $ Just $ untyped q
  emit $ Imp.DebugPrint "groups_per_segment" $ Just $ untyped groups_per_segment

  reds_group_res_arrs <- groupResultArrays (tvSize num_virtgroups) group_size_se segbinops

  -- Operator @i@ uses counter index
  -- @i * num_counters + (flat_segment_id `rem` num_counters)@ (see
  -- @counter_idx@ below), i.e. one block of 'num_counters' counters per
  -- operator.  The array therefore holds 'maxNumOps' such blocks so that
  -- the index stays in bounds for every operator, not just the first.
  -- NOTE(review): assumes 'genZeroes' takes an element count, that
  -- 'reductionStageTwo' indexes @counters@ directly with @counter_idx@,
  -- and that at most 'maxNumOps' operators reach this point -- confirm.
  let num_counters = maxNumOps * 1024
  counters <- genZeroes "counters" $ maxNumOps * num_counters

  let attrs =
        (defKernelAttrs num_groups group_size)
          { kAttrConstExps = M.singleton (tvVar chunk_v) chunk_const
          }

  sKernelThread "segred_large" (segFlat space) attrs $ do
    constants <- kernelConstants <$> askEnv
    let group_id = sExt64 $ kernelGroupId constants
        ltid = kernelLocalThreadId constants

    interms <- makeIntermArrays group_id group_size_se (tvSize chunk_v) segbinops
    -- Single-element Bool array in the "local" space, handed to
    -- 'reductionStageTwo'.
    sync_arr <- sAllocArray "sync_arr" Bool (Shape [intConst Int32 1]) $ Space "local"

    virtualiseGroups SegVirt (sExt32 (tvExp num_virtgroups)) $ \virtgroup_id -> do
      let segment_gtids = init gtids

      flat_segment_id <-
        dPrimVE "flat_segment_id" $
          sExt64 virtgroup_id `quot` groups_per_segment

      -- Thread index relative to the start of its segment.
      global_tid <-
        dPrimVE "global_tid" $
          (sExt64 virtgroup_id * sExt64 group_size' + sExt64 ltid)
            `rem` threads_per_segment

      let first_group_for_segment = flat_segment_id * groups_per_segment
      dIndexSpace (zip segment_gtids (init dims')) flat_segment_id
      -- The innermost index is only declared here; it is passed (as part
      -- of 'gtids') to 'reductionStageOne'.
      dPrim_ (last gtids) int64
      let n = pe64 $ last dims

      slugs <-
        mapM (segBinOpSlug ltid virtgroup_id) $
          zip3 segbinops interms reds_group_res_arrs
      new_lambdas <-
        reductionStageOne
          gtids
          n
          global_tid
          q
          chunk
          threads_per_segment
          slugs
          map_body_cont

      let -- Pattern elements grouped per operator, by number of neutral
          -- elements.
          segred_pess =
            chunks
              (map (length . segBinOpNeutral) segbinops)
              segred_pes

          multiple_groups_per_segment =
            forM_ (zip4 segred_pess slugs new_lambdas [0 ..]) $
              \(pes, slug, new_lambda, i) -> do
                let counter_idx =
                      fromIntegral (i * num_counters)
                        + flat_segment_id
                          `rem` fromIntegral num_counters
                reductionStageTwo
                  pes
                  virtgroup_id
                  (map Imp.le64 segment_gtids)
                  first_group_for_segment
                  groups_per_segment
                  slug
                  new_lambda
                  counters
                  sync_arr
                  counter_idx

          one_group_per_segment =
            sComment "first thread in group saves final result to memory" $
              forM2_ slugs segred_pess $ \slug pes ->
                sWhen (ltid .==. 0) $
                  forM2_ pes (slugAccs slug) $ \v (acc, acc_is) ->
                    copyDWIMFix (patElemName v) (map Imp.le64 segment_gtids) (Var acc) acc_is

      sIf (groups_per_segment .==. 1) one_group_per_segment multiple_groups_per_segment
-- | Everything the code generator tracks for a single reduction operator
-- while compiling a segmented reduction.  Built by 'segBinOpSlug'.
data SegBinOpSlug = SegBinOpSlug
  { -- | The reduction operator itself.
    slugOp :: SegBinOp GPUMem,
    -- | The intermediate arrays associated with this operator.
    slugInterms :: SegRedIntermediateArrays,
    -- | One accumulator per operator result: the variable holding it and
    -- the indices at which it is accessed (empty for a scalar
    -- accumulator).
    slugAccs :: [(VName, [Imp.TExp Int64])],
    -- | The group result arrays that were passed to 'segBinOpSlug'.
    groupResArrs :: [VName]
  }
-- | Build the 'SegBinOpSlug' for one operator.
--
-- The arguments are the local thread id, the group id, and a triple of
-- the operator, its intermediate arrays and its group result arrays.
-- One accumulator is created per group result array, paired with the
-- operator's lambda parameters in order.
segBinOpSlug ::
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  (SegBinOp GPUMem, SegRedIntermediateArrays, [VName]) ->
  InKernelGen SegBinOpSlug
segBinOpSlug ltid group_id (op, interms, group_res_arrs) = do
  let params = lambdaParams $ segBinOpLambda op
  accs <- zipWithM mkAcc params group_res_arrs
  pure
    SegBinOpSlug
      { slugOp = op,
        slugInterms = interms,
        slugAccs = accs,
        groupResArrs = group_res_arrs
      }
  where
    -- True when the operator has no vectorised ("shape") dimensions.
    unshaped = shapeRank (segBinOpShape op) == 0

    -- A primitive-typed parameter of an unshaped operator gets a fresh
    -- scalar accumulator with no indices; anything else accumulates in
    -- the group result array at position [ltid, group_id].
    mkAcc p group_res_arr =
      case paramType p of
        Prim t
          | unshaped -> do
              acc_v <- dPrim (baseString (paramName p) <> "_group_res_acc") t
              pure (tvVar acc_v, [])
        _ ->
          pure (group_res_arr, [sExt64 ltid, sExt64 group_id])
-- | The lambda of the slug's reduction operator.
slugLambda :: SegBinOpSlug -> Lambda GPUMem
slugLambda slug = segBinOpLambda (slugOp slug)
-- | The body of the slug's operator lambda.
slugBody :: SegBinOpSlug -> Body GPUMem
slugBody slug = lambdaBody (slugLambda slug)
-- | All parameters of the slug's operator lambda.
slugParams :: SegBinOpSlug -> [LParam GPUMem]
slugParams slug = lambdaParams (slugLambda slug)
-- | The neutral elements of the slug's operator.
slugNeutral :: SegBinOpSlug -> [SubExp]
slugNeutral slug = segBinOpNeutral (slugOp slug)
-- | The shape of the slug's operator, as given by 'segBinOpShape'.
slugShape :: SegBinOpSlug -> Shape
slugShape slug = segBinOpShape (slugOp slug)
-- | The combined commutativity of a list of slugs: the monoidal
-- concatenation of the commutativity of each slug's operator.
slugsComm :: [SegBinOpSlug] -> Commutativity
slugsComm slugs = mconcat [segBinOpComm (slugOp slug) | slug <- slugs]
-- | Split the operator's lambda parameters into two lists: the first has
-- as many parameters as the operator has neutral elements, the second
-- holds the rest.
slugSplitParams :: SegBinOpSlug -> ([LParam GPUMem], [LParam GPUMem])
slugSplitParams slug = splitAt num_neutral params
  where
    num_neutral = length $ slugNeutral slug
    params = slugParams slug
-- | The 'groupRedArrs' field of the slug's intermediate arrays.
slugGroupRedArrs :: SegBinOpSlug -> [VName]
slugGroupRedArrs slug = groupRedArrs (slugInterms slug)
-- | The 'privateChunks' field of the slug's intermediate arrays.  That
-- field exists only on the 'NoncommPrimSegRedInterms' constructor, so
-- this is partial for 'GeneralSegRedInterms'.
slugPrivChunks :: SegBinOpSlug -> [VName]
slugPrivChunks slug = privateChunks (slugInterms slug)
-- | The 'collCopyArrs' field of the slug's intermediate arrays.  That
-- field exists only on the 'NoncommPrimSegRedInterms' constructor, so
-- this is partial for 'GeneralSegRedInterms'.
slugCollCopyArrs :: SegBinOpSlug -> [VName]
slugCollCopyArrs slug = collCopyArrs (slugInterms slug)
-- | Emit the first stage of the reduction for all the given slugs.
--
-- The function declares the operator parameters of every slug, initialises
-- each accumulator to its neutral element, builds the group-level reduction
-- action, and then delegates the per-thread work to either
-- 'noncommPrimParamsStageOneBody' or 'generalStageOneBody'.
--
-- Arguments, in order:
--
-- * @gtids@: only the last element is used; it names the 64-bit index
--   variable that the stage-one body assigns.
-- * @n@, @global_tid@, @q@: forwarded unchanged to the stage-one body.
-- * @chunk@: forwarded only to 'noncommPrimParamsStageOneBody'.
-- * @threads_per_segment@: forwarded only to 'generalStageOneBody'.
-- * @slugs@: the reduction operators to handle.
-- * @body_cont@: forwarded unchanged to the stage-one body.
--
-- Returns one freshly renamed copy of each slug's lambda, in slug order.
reductionStageOne ::
  [VName] ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  [SegBinOpSlug] ->
  DoSegBody ->
  InKernelGen [Lambda GPUMem]
reductionStageOne gtids n global_tid q chunk threads_per_segment slugs body_cont = do
  constants <- kernelConstants <$> askEnv
  let glb_ind_var = mkTV (last gtids) int64
      ltid = sExt64 (kernelLocalThreadId constants)
      group_size = sExt32 (kernelGroupSize constants)
      only_prims = all (isPrimSegBinOp . slugOp) slugs

  -- Bring the parameters of every reduction operator into scope.
  dScope Nothing (scopeOfLParams (concatMap slugParams slugs))

  sComment "ne-initialise the outer (per-group) accumulator(s)" $
    forM_ slugs $ \slug ->
      forM2_ (slugAccs slug) (slugNeutral slug) $ \(acc, acc_is) ne ->
        sLoopNest (slugShape slug) $ \vec_is ->
          copyDWIMFix acc (acc_is ++ vec_is) ne []

  new_lambdas <- mapM (renameLambda . slugLambda) slugs

  -- Group-level reduction.  It is only built here; the stage-one body decides
  -- where it is emitted.
  let doGroupReduce =
        forM2_ slugs new_lambdas $ \slug new_lambda ->
          sLoopNest (slugShape slug) $ \vec_is -> do
            let accs = slugAccs slug
                red_arrs = slugGroupRedArrs slug
                -- Write the given values into the accumulators, pairwise.
                writeAccs vals =
                  forM2_ accs vals $ \(acc, acc_is) val ->
                    copyDWIMFix acc (acc_is ++ vec_is) val []

            sComment "store accs. prims go in lmem; non-prims in params (in global mem)" $
              forM_ (zip3 red_arrs accs (slugParams slug)) $ \(arr, (acc, acc_is), p) ->
                let (dest, dest_is)
                      | isPrimParam p = (arr, [ltid])
                      | otherwise = (paramName p, [])
                 in copyDWIMFix dest dest_is (Var acc) (acc_is ++ vec_is)

            sOp (Imp.ErrorSync Imp.FenceLocal)
            groupReduce group_size new_lambda red_arrs
            sOp (Imp.Barrier Imp.FenceLocal)

            sComment "thread 0 updates per-group acc(s); rest reset to ne" $
              sIf
                (ltid .==. 0)
                (writeAccs (map (Var . paramName) (lambdaParams new_lambda)))
                (writeAccs (slugNeutral slug))

  -- The chunked body is used only when the combined operator is
  -- noncommutative and every operator satisfies 'isPrimSegBinOp'.
  case slugsComm slugs of
    Noncommutative
      | only_prims ->
          noncommPrimParamsStageOneBody
            slugs
            body_cont
            glb_ind_var
            global_tid
            q
            n
            chunk
            doGroupReduce
    _ ->
      generalStageOneBody
        slugs
        body_cont
        glb_ind_var
        global_tid
        q
        n
        threads_per_segment
        doGroupReduce

  pure new_lambdas
-- | Stage-one body used for every case not handled by
-- 'noncommPrimParamsStageOneBody'.
--
-- It emits a loop of @q@ iterations.  Each iteration assigns the index
-- variable @glb_ind_var@ and, when that index is below @n@, runs @body_cont@
-- and folds the values it produces into the accumulators of each slug using
-- the slug's operator body.
--
-- When the combined operator is commutative, the index for iteration @i@ is
-- @global_tid + threads_per_segment * i@ and @doGroupReduce@ is emitted once,
-- after the loop.  Otherwise the index is @group_offset + ltid@ and
-- @doGroupReduce@ is emitted at the end of every iteration.
generalStageOneBody ::
  [SegBinOpSlug] ->
  DoSegBody ->
  TV Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  InKernelGen () ->
  InKernelGen ()
generalStageOneBody slugs body_cont glb_ind_var global_tid q n threads_per_segment doGroupReduce = do
  constants <- kernelConstants <$> askEnv
  let commutative = slugsComm slugs == Commutative
      group_size = kernelGroupSize constants
      ltid = sExt64 (kernelLocalThreadId constants)
      -- Number of results belonging to each slug, in slug order.
      res_counts = map (length . slugNeutral) slugs

  group_id_in_segment <-
    dPrimVE "group_id_in_segment" (global_tid `quot` group_size)
  group_base_offset <-
    dPrimVE "group_base_offset" (group_id_in_segment * q * group_size)

  sFor "i" q $ \i -> do
    group_offset <- dPrimVE "group_offset" (group_base_offset + i * group_size)
    let next_index
          | commutative = global_tid + threads_per_segment * i
          | otherwise = group_offset + ltid
    glb_ind_var <-- next_index

    sWhen (tvExp glb_ind_var .<. n) $
      sComment "apply map function(s)" $
        body_cont $ \all_red_res ->
          forM2_ slugs (chunks res_counts all_red_res) $ \slug map_res ->
            sLoopNest (slugShape slug) $ \vec_is -> do
              let (acc_params, next_params) = slugSplitParams slug
                  op_body = slugBody slug

              sComment "load accumulator(s)" $
                forM2_ acc_params (slugAccs slug) $ \p (acc, acc_is) ->
                  copyDWIMFix (paramName p) [] (Var acc) (acc_is ++ vec_is)

              sComment "load next value(s)" $
                forM2_ next_params map_res $ \p (res, res_is) ->
                  copyDWIMFix (paramName p) [] res (res_is ++ vec_is)

              sComment "apply reduction operator(s)" $
                compileStms mempty (bodyStms op_body) $
                  sComment "store in accumulator(s)" $
                    forM2_ (slugAccs slug) (map resSubExp (bodyResult op_body)) $
                      \(acc, acc_is) se ->
                        copyDWIMFix acc (acc_is ++ vec_is) se []

    -- Noncommutative case: combine across the group after every iteration.
    unless commutative doGroupReduce

  sOp (Imp.ErrorSync Imp.FenceLocal)

  -- Commutative case: a single group-level reduction after the loop.
  when commutative doGroupReduce
-- | Stage-one body chosen by 'reductionStageOne' when the combined operator is
-- noncommutative and every operator satisfies 'isPrimSegBinOp'.
--
-- It emits a loop of @q@ iterations; each iteration has four phases:
--
-- 1. For every @k@ below @chunk@, assign @glb_ind_var@ to
--    @group_offset + k * group_size + ltid@.  If that index is below @n@, run
--    @body_cont@ and store its results at position @k@ of the private chunks;
--    otherwise store the neutral elements there instead.
--
-- 2. For each pair of collective-copy array and private chunk, write chunk
--    element @k@ to position @ltid + k * group_size@ of the array, emit a
--    barrier, read position @ltid * chunk + k@ back into chunk element @k@,
--    and emit another barrier.
--
-- 3. For every @k@, fold private-chunk element @k@ into the accumulators using
--    each slug's operator body.
--
-- 4. Emit @doLMemGroupReduce@.
--
-- NOTE(review): the first scalar is declared under the label
-- @"group_offset_in_segment"@ although it is bound to @group_id_in_segment@
-- here, and 'generalStageOneBody' labels the same expression
-- @"group_id_in_segment"@.  Left unchanged, since the label is part of the
-- emitted code.
noncommPrimParamsStageOneBody ::
  [SegBinOpSlug] ->
  DoSegBody ->
  TV Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  InKernelGen () ->
  InKernelGen ()
noncommPrimParamsStageOneBody slugs body_cont glb_ind_var global_tid q n chunk doLMemGroupReduce = do
  constants <- kernelConstants <$> askEnv
  let group_size = kernelGroupSize constants
      ltid = sExt64 (kernelLocalThreadId constants)
      -- Number of results belonging to each slug, in slug order.
      res_counts = map (length . slugNeutral) slugs
      -- Loop over the @chunk@ positions of a private chunk.
      forEachChunkElem = sFor "k" chunk

  group_id_in_segment <-
    dPrimVE "group_offset_in_segment" (global_tid `quot` group_size)
  group_stride <- dPrimVE "group_stride" (group_size * chunk)
  group_base_offset <-
    dPrimVE "group_base_offset" (group_id_in_segment * q * group_stride)

  sFor "i" q $ \i -> do
    group_offset <- dPrimVE "group_offset" (group_base_offset + i * group_stride)

    forEachChunkElem $ \k -> do
      loc_ind <- dPrimVE "loc_ind" (k * group_size + ltid)
      glb_ind_var <-- group_offset + loc_ind

      let -- In bounds: keep the produced values in the private chunks.
          storeMapResults all_red_res =
            forM2_ slugs (chunks res_counts all_red_res) $ \slug slug_res ->
              sComment "write map result(s) to private chunk(s)" $
                forM2_ (slugPrivChunks slug) slug_res $ \priv_chunk (res, res_is) ->
                  copyDWIMFix priv_chunk [k] res res_is
          -- Out of bounds: pad the private chunks with neutral elements.
          padWithNeutral =
            forM_ slugs $ \slug ->
              forM2_ (slugPrivChunks slug) (slugNeutral slug) $ \priv_chunk ne ->
                copyDWIMFix priv_chunk [k] ne []

      sIf (tvExp glb_ind_var .<. n) (body_cont storeMapResults) padWithNeutral

    sOp (Imp.ErrorSync Imp.FenceLocal)

    sComment "effectualize collective copies in local memory" $
      forM_ slugs $ \slug ->
        forM2_ (slugCollCopyArrs slug) (slugPrivChunks slug) $ \lmem_arr priv_chunk -> do
          -- Write with index ltid + k * group_size ...
          forEachChunkElem $ \k -> do
            lmem_idx <- dPrimVE "lmem_idx" (ltid + k * group_size)
            copyDWIMFix lmem_arr [lmem_idx] (Var priv_chunk) [k]

          sOp (Imp.Barrier Imp.FenceLocal)

          -- ... and read back with index ltid * chunk + k.
          forEachChunkElem $ \k -> do
            lmem_idx <- dPrimVE "lmem_idx" (ltid * chunk + k)
            copyDWIMFix priv_chunk [k] (Var lmem_arr) [lmem_idx]

          sOp (Imp.Barrier Imp.FenceLocal)

    sComment "per-thread sequential reduction of private chunk(s)" $
      forEachChunkElem $ \k ->
        forM_ slugs $ \slug -> do
          let accs = map fst (slugAccs slug)
              (acc_ps, next_ps) = slugSplitParams slug
              op_body = slugBody slug

          sComment "load params for all reductions" $
            forM_ (zip4 acc_ps next_ps accs (slugPrivChunks slug)) $
              \(acc_p, next_p, acc, priv_chunk) -> do
                copyDWIM (paramName acc_p) [] (Var acc) []
                copyDWIMFix (paramName next_p) [] (Var priv_chunk) [k]

          sComment "apply reduction operator(s)" $
            compileStms mempty (bodyStms op_body) $
              forM2_ accs (map resSubExp (bodyResult op_body)) $ \acc binop_res ->
                copyDWIM acc [] binop_res []

    doLMemGroupReduce

  sOp (Imp.ErrorSync Imp.FenceLocal)
-- | Second stage of a multi-group segmented reduction: combine the
-- per-group partial results of one segment into the final result.
--
-- Every group first publishes its own partial result and bumps a
-- counter; the group that observes @groups_per_segment - 1@ as the
-- counter value then reads all partial results of the segment back,
-- reduces them with 'groupReduce', and writes the result to the
-- pattern elements.
--
-- Arguments, in order:
--
-- * @segred_pes@: pattern elements that receive the final result
--   (written at index @segment_gtids ++ vec_is@).
--
-- * @group_id@: index under which this group stores its partial
--   result in the slug's 'groupResArrs'.
--
-- * @segment_gtids@: leading indices of the final write.
--
-- * @first_group_for_segment@: offset added to a partial-result id to
--   obtain its index in the 'groupResArrs'.
--
-- * @groups_per_segment@: number of partial results to combine.
--
-- * @slug@: supplies operator parameters, neutral elements, body,
--   vector shape and the intermediate arrays.
--
-- * @new_lambda@: operator handed to 'groupReduce'; its parameters are
--   read afterwards as the source of the final result.
--
-- * @counters@, @counter_idx@: counter array and the index of the
--   counter used here.
--
-- * @sync_arr@: one-element array through which thread 0 tells the
--   rest of the group whether the final stage has to be run.
reductionStageTwo ::
  [PatElem LetDecMem] ->
  Imp.TExp Int32 ->
  [Imp.TExp Int64] ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  SegBinOpSlug ->
  Lambda GPUMem ->
  VName ->
  VName ->
  Imp.TExp Int64 ->
  InKernelGen ()
reductionStageTwo segred_pes group_id segment_gtids first_group_for_segment groups_per_segment slug new_lambda counters sync_arr counter_idx = do
  constants <- kernelConstants <$> askEnv

  let ltid32 = kernelLocalThreadId constants
      ltid = sExt64 ltid32
      group_size = kernelGroupSize constants

  let (acc_params, next_params) = slugSplitParams slug
      nes = slugNeutral slug
      red_arrs = slugGroupRedArrs slug
      group_res_arrs = groupResArrs slug

  -- NOTE(review): declared as int32 here, yet compared against the
  -- 64-bit groups_per_segment below -- confirm the widths agree in the
  -- generated code.
  old_counter <- dPrim "old_counter" int32
  (counter_mem, _, counter_offset) <-
    fullyIndexArray
      counters
      [counter_idx]
  sComment "first thread in group saves group result to global memory" $
    sWhen (ltid32 .==. 0) $ do
      -- Only the first (length nes) accumulators are published.
      forM_ (take (length nes) $ zip group_res_arrs (slugAccs slug)) $ \(v, (acc, acc_is)) ->
        copyDWIMFix v [0, sExt64 group_id] (Var acc) acc_is
      sOp $ Imp.MemFence Imp.FenceGlobal
      -- Add one to the counter.  Presumably old_counter receives the
      -- value the counter had before the addition -- confirm against
      -- Imp.AtomicAdd.
      sOp
        $ Imp.Atomic DefaultSpace
        $ Imp.AtomicAdd
          Int32
          (tvVar old_counter)
          counter_mem
          counter_offset
        $ untyped (1 :: Imp.TExp Int32)
      -- Record whether this group saw the last expected counter value,
      -- i.e. whether it must produce the final result.
      sWrite sync_arr [0] $ untyped $ tvExp old_counter .==. groups_per_segment - 1

  sOp $ Imp.Barrier Imp.FenceGlobal

  -- Every thread picks up the flag written by thread 0 above.
  is_last_group <- dPrim "is_last_group" Bool
  copyDWIMFix (tvVar is_last_group) [] (Var sync_arr) [0]
  sWhen (tvExp is_last_group) $ do
    -- Subtract groups_per_segment from the counter again, undoing the
    -- increments made above.
    sWhen (ltid32 .==. 0) $
      sOp $
        Imp.Atomic DefaultSpace $
          Imp.AtomicAdd Int32 (tvVar old_counter) counter_mem counter_offset $
            untyped $
              negate groups_per_segment

    sLoopNest (slugShape slug) $ \vec_is -> do
      -- With a non-empty vector shape this body runs several times and
      -- reuses red_arrs, so synchronise between iterations.
      unless (null $ slugShape slug) $
        sOp (Imp.Barrier Imp.FenceLocal)

      sComment "read in the per-group-results" $ do
        -- groups_per_segment may exceed the group size, so each thread
        -- sequentially folds a contiguous run of partial results.
        read_per_thread <-
          dPrimVE "read_per_thread" $
            groups_per_segment `divUp` sExt64 group_size

        -- Start every accumulator from its neutral element.
        forM2_ acc_params nes $ \p ne ->
          copyDWIM (paramName p) [] ne []

        sFor "i" read_per_thread $ \i -> do
          group_res_id <-
            dPrimVE "group_res_id" $
              ltid * read_per_thread + i
          index_of_group_res <-
            dPrimVE "index_of_group_res" $
              first_group_for_segment + group_res_id

          -- The last threads may run past the end of the segment.
          sWhen (group_res_id .<. groups_per_segment) $ do
            forM2_ next_params group_res_arrs $
              \p group_res_arr ->
                copyDWIMFix
                  (paramName p)
                  []
                  (Var group_res_arr)
                  ([0, index_of_group_res] ++ vec_is)

            -- Apply the operator body and store its results back into
            -- the accumulator parameters.
            compileStms mempty (bodyStms $ slugBody slug) $
              forM2_ acc_params (map resSubExp $ bodyResult $ slugBody slug) $ \p se ->
                copyDWIMFix (paramName p) [] se []

      -- Primitive accumulators are handed to the group reduction
      -- through red_arrs, one slot per thread.
      forM2_ acc_params red_arrs $ \p arr ->
        when (isPrimParam p) $
          copyDWIMFix arr [ltid] (Var $ paramName p) []

      sOp $ Imp.ErrorSync Imp.FenceLocal

      sComment "reduce the per-group results" $ do
        groupReduce (sExt32 group_size) new_lambda red_arrs

        sComment "and back to memory with the final result" $
          sWhen (ltid32 .==. 0) $
            forM2_ segred_pes (lambdaParams new_lambda) $ \pe p ->
              copyDWIMFix
                (patElemName pe)
                (segment_gtids ++ vec_is)
                (Var $ paramName p)
                []