module Futhark.CodeGen.ImpGen.Multicore.SegRed
( compileSegRed,
compileSegRed',
DoSegBody,
)
where
import Control.Monad
import Futhark.CodeGen.ImpCode.Multicore qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.IR.MCMem
import Futhark.Transform.Rename (renameLambda)
import Prelude hiding (quot, rem)
-- | A body generator in continuation-passing style.  The generator is
-- handed a continuation and is expected to call it with the values to be
-- reduced: one inner list per reduction operator, each element being a
-- value together with the indices used to read it.
type DoSegBody =
  ([[(SubExp, [Imp.TExp Int64])]] -> MulticoreGen ()) ->
  MulticoreGen ()
-- | Generate code for a segmented reduction whose per-iteration work is
-- given as a 'KernelBody'.
--
-- The kernel body's statements are compiled, and its results are split at
-- @segBinOpResults reds@: the leading results are the values to reduce,
-- the remaining ones are map-out results.  Map-out results are written to
-- the matching trailing pattern elements with 'compileThreadResult'; the
-- values to reduce are chunked per operator (each paired with an empty
-- index list) and passed to the continuation supplied by 'compileSegRed''.
compileSegRed ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.MCCode
compileSegRed pat space reds kbody nsubtasks =
  compileSegRed' pat space reds nsubtasks $ \red_cont ->
    compileStms mempty (kernelBodyStms kbody) $ do
      let (red_res, map_res) =
            splitAt (segBinOpResults reds) $ kernelBodyResult kbody

      sComment "save map-out results" $ do
        -- The pattern elements past the reduction results receive the
        -- map-out values.
        let map_arrs = drop (segBinOpResults reds) $ patElems pat
        zipWithM_ (compileThreadResult space) map_arrs map_res

      red_cont $
        segBinOpChunks reds $
          zip (map kernelResultSubExp red_res) $
            repeat []
-- | Like 'compileSegRed', but the body is supplied as a 'DoSegBody'
-- action instead of a 'KernelBody'.
--
-- When the 'SegSpace' has exactly one dimension the work is handed to
-- 'nonsegmentedReduction'; otherwise 'segmentedReduction' is used (which
-- does not take the subtask count).
compileSegRed' ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  TV Int32 ->
  DoSegBody ->
  MulticoreGen Imp.MCCode
compileSegRed' pat space reds nsubtasks kbody
  | [_] <- unSegSpace space =
      nonsegmentedReduction pat space reds nsubtasks kbody
  | otherwise =
      segmentedReduction pat space reds kbody
-- | A 'SegBinOp' bundled with the arrays that receive its intermediate
-- results.
data SegBinOpSlug = SegBinOpSlug
  { -- | The reduction operator.
    slugOp :: SegBinOp MCMem,
    -- | Result arrays, zipped one-to-one with the local accumulators in
    -- 'genWriteBack', which stores each accumulator at the index given
    -- by the flat id of the 'SegSpace'.
    slugResArrs :: [VName]
  }
-- | The body of the operator's lambda.
slugBody :: SegBinOpSlug -> Body MCMem
slugBody = lambdaBody . segBinOpLambda . slugOp
-- | All parameters of the operator's lambda.
slugParams :: SegBinOpSlug -> [LParam MCMem]
slugParams = lambdaParams . segBinOpLambda . slugOp
-- | The neutral elements of the operator.
slugNeutral :: SegBinOpSlug -> [SubExp]
slugNeutral = segBinOpNeutral . slugOp
-- | The shape attached to the operator ('segBinOpShape').
slugShape :: SegBinOpSlug -> Shape
slugShape = segBinOpShape . slugOp
-- | Split the lambda parameters at the number of neutral elements:
-- 'accParams' is the leading part (as many parameters as there are
-- neutral elements), 'nextParams' is everything after that.
accParams, nextParams :: SegBinOpSlug -> [LParam MCMem]
accParams slug = take (length (slugNeutral slug)) $ slugParams slug
nextParams slug = drop (length (slugNeutral slug)) $ slugParams slug
-- | Return a copy of the slug whose operator lambda has been passed
-- through 'renameLambda'.  The result arrays are left untouched.
renameSlug :: SegBinOpSlug -> MulticoreGen SegBinOpSlug
renameSlug slug = do
  let op = slugOp slug
  let lambda = segBinOpLambda op
  lambda' <- renameLambda lambda
  let op' = op {segBinOpLambda = lambda'}
  pure slug {slugOp = op'}
-- | Generate code for a reduction over a single-dimensional 'SegSpace'.
--
-- The generated code has two stages.  Stage one runs one of four
-- specialised generators, each of which writes its results into arrays
-- obtained from 'groupResultArrays' (sized by the number of subtasks).
-- Stage two ('reductionStage2') then runs over those same arrays using
-- renamed copies of the operators.
--
-- Stage-one selection, in priority order:
--
-- * all operators commutative and everything scalar:
--   'reductionStage1CommScalar';
-- * no operator has an empty shape: 'reductionStage1Array';
-- * everything scalar: 'reductionStage1NonCommScalar';
-- * otherwise: 'reductionStage1Fallback'.
nonsegmentedReduction ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  TV Int32 ->
  DoSegBody ->
  MulticoreGen Imp.MCCode
nonsegmentedReduction pat space reds nsubtasks kbody = collect $ do
  thread_res_arrs <-
    groupResultArrays "reduce_stage_1_tid_res_arr" (tvSize nsubtasks) reds
  let slugs1 = zipWith SegBinOpSlug reds thread_res_arrs
      nsubtasks' = tvExp nsubtasks

  -- Are all the operators commutative?
  let comm = all ((== Commutative) . segBinOpComm) reds
  let dims = map (shapeDims . slugShape) slugs1
  let isScalar x = case x of MemPrim _ -> True; _ -> False
  -- Every lambda parameter is a primitive and every operator shape is
  -- empty.
  let scalars =
        all (all (isScalar . paramDec) . slugParams) slugs1
          && all (== []) dims
  -- No operator has an empty shape.
  let inner_map = [] `notElem` dims

  let path
        | comm && scalars = reductionStage1CommScalar
        | inner_map = reductionStage1Array
        | scalars = reductionStage1NonCommScalar
        | otherwise = reductionStage1Fallback
  path space slugs1 kbody

  -- Stage two uses renamed operators but the same intermediate arrays.
  reds2 <- renameSegBinOp reds
  let slugs2 = zipWith SegBinOpSlug reds2 thread_res_arrs
  reductionStage2 pat space nsubtasks' slugs2
-- | Declare the lambda parameters of every slug, by passing their
-- combined scope to 'dScope'.
genBinOpParams :: [SegBinOpSlug] -> MulticoreGen ()
genBinOpParams slugs =
  dScope Nothing $ scopeOfLParams $ concatMap slugParams slugs
-- | Create and neutral-initialise the local accumulators of every slug.
--
-- For each accumulator parameter with a primitive type, a fresh
-- @local_acc@ is declared: a primitive variable when the operator shape
-- is empty, otherwise an array of that shape allocated in
-- 'DefaultSpace'.  For any other parameter type the parameter's own name
-- is used as the accumulator.  Each accumulator is then filled with the
-- corresponding neutral element.
--
-- Returns the accumulator names, one inner list per slug.
genAccumulators :: [SegBinOpSlug] -> MulticoreGen [[VName]]
genAccumulators slugs =
  forM slugs $ \slug -> do
    let shape = segBinOpShape $ slugOp slug
    forM (zip (accParams slug) (slugNeutral slug)) $ \(p, ne) -> do
      -- Declare the accumulator.
      acc <-
        case paramType p of
          Prim pt
            | shape == mempty ->
                tvVar <$> dPrim "local_acc" pt
            | otherwise ->
                sAllocArray "local_acc" pt shape DefaultSpace
          _ ->
            pure $ paramName p

      -- Initialise it with the neutral element.
      sLoopNest (slugShape slug) $ \vec_is ->
        copyDWIMFix acc vec_is ne []

      pure acc
-- | Selects which loop/extraction helpers 'genReductionLoop' uses
-- (see 'getRedLoop', 'getExtract' and 'getNestLoop').
data RedLoopType
  = -- | Used by 'reductionStage1Fallback'.
    RedSeq
  | -- | Used by 'reductionStage1CommScalar' for its main loop.
    RedComm
  | -- | Used by 'reductionStage1NonCommScalar'.
    RedNonComm
  | -- | Used by 'reductionStage1Array'.
    RedNested
  | -- | Used by 'genPostbodyReductionLoop'.
    RedUniformize
-- | Pick the wrapper placed around the per-element reduction code.
--
-- * 'RedNonComm': the body is run inside 'generateUniformizeLoop'; the
--   supplied index is ignored.
-- * 'RedUniformize': the body is applied directly to the supplied index.
-- * anything else: the body is applied to the constant @0@.
getRedLoop ::
  RedLoopType ->
  Imp.TExp Int64 ->
  (Imp.TExp Int64 -> MulticoreGen ()) ->
  MulticoreGen ()
getRedLoop RedNonComm _ = generateUniformizeLoop
getRedLoop RedUniformize uni = \body -> body uni
getRedLoop _ _ = \body -> body 0
-- | Pick how the code that loads a "next" parameter is emitted.
--
-- For 'RedNonComm' and 'RedUniformize' the collected code is passed to
-- 'extractVectorLane' together with the lane index.  For every other
-- loop type the index is ignored and the collected code is emitted
-- unchanged.
getExtract ::
  RedLoopType ->
  Imp.TExp Int64 ->
  MulticoreGen Imp.MCCode ->
  MulticoreGen ()
getExtract RedNonComm = extractVectorLane
getExtract RedUniformize = extractVectorLane
getExtract _ = \_ body -> body >>= emit
-- | Pick the loop-nest generator used to iterate over an operator's
-- shape: 'sLoopNestVectorized' for 'RedNested', plain 'sLoopNest' for
-- every other loop type.
getNestLoop ::
  RedLoopType ->
  Shape ->
  ([Imp.TExp Int64] -> MulticoreGen ()) ->
  MulticoreGen ()
getNestLoop RedNested = sLoopNestVectorized
getNestLoop _ = sLoopNest
-- | Build a 'DoSegBody' whose values to reduce are existing accumulators:
-- the continuation receives each accumulator as a 'Var' paired with an
-- empty index list, keeping the per-slug grouping.
redSourceAccs :: [[VName]] -> DoSegBody
redSourceAccs slug_local_accs m =
  m $ map (map (\x -> (Var x, []))) slug_local_accs
-- | A reduction loop of type 'RedUniformize' whose inputs are read from
-- the given accumulators (via 'redSourceAccs') rather than produced by a
-- kernel body.  The remaining arguments are those of 'genReductionLoop'.
genPostbodyReductionLoop ::
  [[VName]] ->
  [SegBinOpSlug] ->
  [[VName]] ->
  SegSpace ->
  Imp.TExp Int64 ->
  MulticoreGen ()
genPostbodyReductionLoop accs =
  genReductionLoop RedUniformize (redSourceAccs accs)
-- | Generate the code for one iteration @i@ of a reduction loop.
--
-- The flat index @i@ is first unflattened into the index variables of
-- the 'SegSpace'.  The body generator is then run, and for every
-- operator (with its values to reduce and its local accumulators) the
-- following is emitted inside the loop nest over the operator's shape:
--
-- 1. load the accumulator parameters from the local accumulators (only
--    for results of primitive type);
-- 2. load the "next" parameters from the values to reduce;
-- 3. compile the operator body and store its results back into the
--    local accumulators.
--
-- The 'RedLoopType' selects the concrete loop and extraction helpers via
-- 'getNestLoop', 'getRedLoop' and 'getExtract'.
genReductionLoop ::
  RedLoopType ->
  DoSegBody ->
  [SegBinOpSlug] ->
  [[VName]] ->
  SegSpace ->
  Imp.TExp Int64 ->
  MulticoreGen ()
genReductionLoop typ kbodymap slugs slug_local_accs space i = do
  let (is, ns) = unzip $ unSegSpace space
      ns' = map pe64 ns
  zipWithM_ dPrimV_ is $ unflattenIndex ns' i
  kbodymap $ \all_red_res' -> do
    forM_ (zip3 all_red_res' slugs slug_local_accs) $ \(red_res, slug, local_accs) ->
      getNestLoop typ (slugShape slug) $ \vec_is -> do
        let lamtypes = lambdaReturnType $ segBinOpLambda $ slugOp slug
        getRedLoop typ i $ \uni -> do
          sComment "Load accum params" $
            forM_ (zip3 (accParams slug) local_accs lamtypes) $
              \(p, local_acc, t) ->
                -- Only results of primitive type are copied into the
                -- parameter.
                when (primType t) $ do
                  copyDWIMFix (paramName p) [] (Var local_acc) vec_is

          sComment "Load next params" $
            forM_ (zip (nextParams slug) red_res) $ \(p, (res, res_is)) -> do
              getExtract typ uni $
                collect $
                  copyDWIMFix (paramName p) [] res (res_is ++ vec_is)

          sComment "SegRed body" $
            compileStms mempty (bodyStms $ slugBody slug) $
              forM_ (zip local_accs $ map resSubExp $ bodyResult $ slugBody slug) $
                \(local_acc, se) ->
                  copyDWIMFix local_acc vec_is se []
-- | Store every local accumulator into the matching result array of its
-- slug, at the index given by the flat id of the 'SegSpace'.
genWriteBack :: [SegBinOpSlug] -> [[VName]] -> SegSpace -> MulticoreGen ()
genWriteBack slugs slug_local_accs space =
  forM_ (zip slugs slug_local_accs) $ \(slug, local_accs) ->
    -- Results are not needed, so 'forM_' avoids building a list of units.
    forM_ (zip (slugResArrs slug) local_accs) $ \(acc, local_acc) ->
      copyDWIMFix acc [Imp.le64 $ segFlat space] (Var local_acc) []
-- | The common signature of the stage-one generators: given the
-- iteration space, the operators (with their result arrays) and the body
-- generator, emit the stage-one code.
type ReductionStage1 =
  SegSpace ->
  [SegBinOpSlug] ->
  DoSegBody ->
  MulticoreGen ()
-- | Stage-one generator used when none of the specialised variants
-- applies (see the selection in 'nonsegmentedReduction').
--
-- Emits a @segred_stage_1@ 'Imp.ParLoop' whose body fetches the task id
-- into the flat index variable, declares the operator parameters and
-- accumulators, runs a 'Scalar' chunk loop of type 'RedSeq', and finally
-- writes the accumulators back.
reductionStage1Fallback :: ReductionStage1
reductionStage1Fallback space slugs kbody = do
  fbody <- collect $ do
    dPrim_ (segFlat space) int64
    sOp $ Imp.GetTaskId (segFlat space)
    -- Declare params and accumulators.
    genBinOpParams slugs
    slug_local_accs <- genAccumulators slugs
    -- Main reduction loop.
    generateChunkLoop "SegRed" Scalar $
      genReductionLoop RedSeq kbody slugs slug_local_accs space
    -- Write back results.
    genWriteBack slugs slug_local_accs space
  free_params <- freeParams fbody
  emit $ Imp.Op $ Imp.ParLoop "segred_stage_1" fbody free_params
-- | Stage-one generator selected for scalar operators that are not all
-- commutative (see 'nonsegmentedReduction').
--
-- Emits a @segred_stage_1@ 'Imp.ParLoop'.  After fetching the task id,
-- everything else (parameter and accumulator declarations, a
-- 'Vectorized' chunk loop of type 'RedNonComm', and the write-back) is
-- generated inside 'inISPC'.
reductionStage1NonCommScalar :: ReductionStage1
reductionStage1NonCommScalar space slugs kbody = do
  fbody <- collect $ do
    dPrim_ (segFlat space) int64
    sOp $ Imp.GetTaskId (segFlat space)
    inISPC $ do
      genBinOpParams slugs
      slug_local_accs <- genAccumulators slugs
      generateChunkLoop "SegRed" Vectorized $
        genReductionLoop RedNonComm kbody slugs slug_local_accs space
      genWriteBack slugs slug_local_accs space
  free_params <- freeParams fbody
  emit $ Imp.Op $ Imp.ParLoop "segred_stage_1" fbody free_params
-- | Stage-one generator selected when all operators are commutative and
-- everything is scalar (see 'nonsegmentedReduction').
--
-- Two sets of parameters and accumulators are declared inside 'inISPC':
-- one for renamed copies of the slugs and one for the original slugs.
-- The main 'Vectorized' chunk loop (type 'RedComm') accumulates into the
-- accumulators of the original slugs.  A following
-- 'generateUniformizeLoop' then reduces those accumulators into the ones
-- belonging to the renamed slugs, and these are what is written back.
--
-- NOTE(review): the @_uni@ naming suggests the renamed set is meant to
-- be ISPC-uniform and the other varying; confirm against 'inISPC'.
reductionStage1CommScalar :: ReductionStage1
reductionStage1CommScalar space slugs kbody = do
  fbody <- collect $ do
    dPrim_ (segFlat space) int64
    sOp $ Imp.GetTaskId (segFlat space)
    -- Rename the lambdas so a second, distinct set of params exists.
    slugs' <- mapM renameSlug slugs
    inISPC $ do
      -- Params and accumulators for the renamed slugs.
      genBinOpParams slugs'
      slug_local_accs_uni <- genAccumulators slugs'
      -- Params and accumulators for the original slugs.
      genBinOpParams slugs
      slug_local_accs <- genAccumulators slugs
      -- Main reduction loop.
      generateChunkLoop "SegRed" Vectorized $
        genReductionLoop RedComm kbody slugs slug_local_accs space
      -- Reduce the main-loop accumulators into the second set.
      generateUniformizeLoop $
        genPostbodyReductionLoop slug_local_accs slugs' slug_local_accs_uni space
      -- Write back the second set of accumulators.
      genWriteBack slugs slug_local_accs_uni space
  free_params <- freeParams fbody
  emit $ Imp.Op $ Imp.ParLoop "segred_stage_1" fbody free_params
-- | Stage-one generator selected when no operator has an empty shape
-- (see 'nonsegmentedReduction').
--
-- The code declaring the lambda parameters and the code creating the
-- accumulators are collected separately: the accumulator code is emitted
-- before 'inISPC', the parameter declarations inside it.  The main loop
-- is a 'Scalar' chunk loop of type 'RedNested', followed by the
-- write-back.
--
-- NOTE(review): emitting the accumulators outside 'inISPC' presumably
-- makes them uniform while the params stay varying; confirm.
reductionStage1Array :: ReductionStage1
reductionStage1Array space slugs kbody = do
  fbody <- collect $ do
    dPrim_ (segFlat space) int64
    sOp $ Imp.GetTaskId (segFlat space)
    -- Collect, but do not yet emit, the param declarations.
    lparams <- collect $ genBinOpParams slugs
    (slug_local_accs, uniform_prebody) <- collect' $ genAccumulators slugs
    -- Accumulators go before the ISPC section.
    emit uniform_prebody
    inISPC $ do
      -- Lambda params go inside the ISPC section.
      emit lparams
      -- Main reduction loop.
      generateChunkLoop "SegRed" Scalar $
        genReductionLoop RedNested kbody slugs slug_local_accs space
      -- Write back results.
      genWriteBack slugs slug_local_accs space
  free_params <- freeParams fbody
  emit $ Imp.Op $ Imp.ParLoop "segred_stage_1" fbody free_params
-- | Second stage of a non-segmented reduction: a sequential loop that
-- folds the values stored in each operator's result arrays into the
-- result elements of @pat@.
--
-- The pattern elements are first filled with the operators' neutral
-- elements.  The lambda parameters of all slugs are then declared, and
-- a loop of @nsubtasks@ iterations writes its counter into the flat
-- index variable of @space@.  For every operator (and every point of
-- its vector shape) the iteration loads the current result into the
-- accumulator parameters, loads the value found at the flat index of
-- 'slugResArrs' into the remaining parameters, runs the operator body,
-- and stores the body's results back into the pattern elements.
reductionStage2 ::
  Pat LetDecMem ->
  SegSpace ->
  Imp.TExp Int32 ->
  [SegBinOpSlug] ->
  MulticoreGen ()
reductionStage2 pat space nsubtasks slugs = do
  sComment "neutral-initialise the output" $
    forM_ (zip ops pes_per_op) $ \(op, op_pes) ->
      forM_ (zip op_pes (segBinOpNeutral op)) $ \(dest, neutral) ->
        sLoopNest (segBinOpShape op) $ \vec_is ->
          copyDWIMFix (patElemName dest) vec_is neutral []

  dScope Nothing . scopeOfLParams $ concatMap slugParams slugs

  sFor "i" nsubtasks $ \subtask -> do
    -- NOTE(review): the counter is an Int32 expression while the flat
    -- index variable is built with 'int64'; kept exactly as found.
    mkTV flat_id int64 <-- subtask
    sComment "Apply main thread reduction" $
      forM_ (zip slugs pes_per_op) $ \(slug, op_pes) ->
        sLoopNest (slugShape slug) (combine slug op_pes)
  where
    flat_id = segFlat space
    ops = map slugOp slugs
    -- Pattern elements grouped by the operator that produces them.
    pes_per_op = segBinOpChunks ops (patElems pat)
    -- The flat index read back as a 64-bit expression; used as the
    -- outermost index into the slugs' result arrays.
    subtask_id = Imp.le64 flat_id

    -- Apply one operator at one vector index: result `op` stored value.
    combine slug op_pes vec_is = do
      let body = slugBody slug
          body_res = map resSubExp (bodyResult body)
      sComment "load acc params" $
        forM_ (zip (accParams slug) op_pes) $ \(param, dest) ->
          copyDWIMFix (paramName param) [] (Var (patElemName dest)) vec_is
      sComment "load next params" $
        forM_ (zip (nextParams slug) (slugResArrs slug)) $ \(param, arr) ->
          copyDWIMFix (paramName param) [] (Var arr) (subtask_id : vec_is)
      sComment "red body" $
        compileStms mempty (bodyStms body) $
          forM_ (zip op_pes body_res) $ \(dest, res) ->
            copyDWIMFix (patElemName dest) vec_is res []
-- | Generate code for a segmented reduction.
--
-- The loop body is produced by 'compileSegRedBody'; its free
-- parameters are computed with 'freeParams', and the body is wrapped
-- in a single 'Imp.ParLoop' named @"segmented_segred"@.  The emitted
-- code is collected and returned rather than emitted in the caller.
segmentedReduction ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  DoSegBody ->
  MulticoreGen Imp.MCCode
segmentedReduction pat space reds kbody = collect $ do
  loop_code <- compileSegRedBody pat space reds kbody
  loop_frees <- freeParams loop_code
  emit . Imp.Op $ Imp.ParLoop "segmented_segred" loop_code loop_frees
-- | Build the loop body used by 'segmentedReduction'.
--
-- The flat index variable of @space@ is declared and set with
-- 'Imp.GetTaskId'.  The returned code is a vectorised chunk loop
-- (inside 'inISPC') in which every iteration:
--
--   * derives the space indices from the chunk-loop index multiplied
--     by the innermost dimension size,
--
--   * writes the neutral elements into the pattern elements, indexed
--     by all but the last space index followed by the vector indices,
--
--   * runs a sequential loop over the innermost dimension that invokes
--     @kbody@ and, per operator, loads the accumulator and the newly
--     produced value into the lambda parameters, runs the lambda body
--     and writes its results back to the same pattern-element slot.
compileSegRedBody ::
  Pat LetDecMem ->
  SegSpace ->
  [SegBinOp MCMem] ->
  DoSegBody ->
  MulticoreGen Imp.MCCode
compileSegRedBody pat space reds kbody = do
  dPrim_ flat_id int64
  sOp $ Imp.GetTaskId flat_id

  collect . inISPC $
    generateChunkLoop "SegRed" Vectorized $ \n_segments -> do
      flat_idx <- dPrimVE "flat_idx" (n_segments * inner_bound)
      zipWithM_ dPrimV_ gtids (unflattenIndex dims_64 flat_idx)

      sComment "neutral-initialise the accumulators" $
        forM_ (zip pes_per_op reds) $ \(op_pes, op) ->
          forM_ (zip op_pes (segBinOpNeutral op)) $ \(dest, neutral) ->
            sLoopNest (segBinOpShape op) $ \vec_is ->
              copyDWIMFix (patElemName dest) (accIndex vec_is) neutral []

      sComment "main body" $ do
        dScope Nothing . scopeOfLParams $ concatMap opParams reds
        sFor "i" inner_bound $ \i -> do
          -- Re-assign the outer indices on every iteration, then bind
          -- the innermost index to the loop counter.
          zipWithM_
            (<--)
            (map (`mkTV` int64) outer_gtids)
            (unflattenIndex (init dims_64) (sExt64 n_segments))
          dPrimV_ (last gtids) i
          kbody $ \red_res' ->
            forM_ (zip3 pes_per_op reds red_res') $ \(op_pes, op, op_res) ->
              sLoopNest (segBinOpShape op) (foldStep op_pes op op_res)
  where
    flat_id = segFlat space
    (gtids, dims) = unzip (unSegSpace space)
    dims_64 = map pe64 dims
    inner_bound = last dims_64
    outer_gtids = init gtids

    -- Pattern elements grouped by the operator that produces them.
    pes_per_op = segBinOpChunks reds (patElems pat)
    opParams = lambdaParams . segBinOpLambda

    -- Index of an accumulator slot: outer space indices, then the
    -- position within the operator's vector shape.
    accIndex vec_is = map Imp.le64 outer_gtids ++ vec_is

    -- Fold one freshly computed value into the accumulator of one
    -- operator at one vector index.
    foldStep op_pes op op_res vec_is = do
      let (acc_params, next_params) =
            splitAt (length (segBinOpNeutral op)) (opParams op)
          lbody = lambdaBody (segBinOpLambda op)
          slot = accIndex vec_is
      sComment "load accum" $
        forM_ (zip acc_params op_pes) $ \(param, dest) ->
          copyDWIMFix (paramName param) [] (Var (patElemName dest)) slot

      sComment "load new val" $
        forM_ (zip next_params op_res) $ \(param, (res, res_is)) ->
          copyDWIMFix (paramName param) [] res (res_is ++ vec_is)

      sComment "apply reduction" $
        compileStms mempty (bodyStms lbody) $
          sComment "write back to res" $
            forM_ (zip op_pes (map resSubExp (bodyResult lbody))) $
              \(dest, res) -> copyDWIMFix (patElemName dest) slot res []