module Futhark.CodeGen.ImpGen.Multicore.SegHist
( compileSegHist,
)
where
import Control.Monad
import Data.List (zip4)
import Futhark.CodeGen.ImpCode.Multicore qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.CodeGen.ImpGen.Multicore.SegRed (compileSegRed')
import Futhark.IR.MCMem
import Futhark.MonadFreshNames
import Futhark.Transform.Rename (renameLambda)
import Futhark.Util (chunks, splitFromEnd, takeLast)
import Futhark.Util.IntegralExp (rem)
import Prelude hiding (quot, rem)
-- | Compile a 'SegHist' operation to multicore ImpCode.
--
-- A segment space with exactly one dimension is a plain (non-segmented)
-- histogram and is handled by 'nonsegmentedHist', which also receives the
-- subtask count.  Every other segment space is handled by 'segmentedHist'.
compileSegHist ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.MCCode
compileSegHist pat space histops kbody nsubtasks
  | [_] <- unSegSpace space =
      nonsegmentedHist pat space histops kbody nsubtasks
  | otherwise =
      segmentedHist pat space histops kbody
-- | Split a flat list into one chunk per histogram operator.  The chunk
-- for each operator is as long as that operator's list of neutral
-- elements ('histNeutral').
segHistOpChunks :: [HistOp rep] -> [a] -> [[a]]
segHistOpChunks = chunks . map (length . histNeutral)
-- | Total number of buckets of a histogram operator: the product of the
-- dimensions of its destination shape ('histShape').
histSize :: HistOp MCMem -> Imp.TExp Int64
histSize = product . map pe64 . shapeDims . histShape
-- | Declare the parameters of a histogram operator's lambda in the
-- current ImpGen scope (via 'dScope', with no defining expression).
genHistOpParams :: HistOp MCMem -> MulticoreGen ()
genHistOpParams op =
  dScope Nothing . scopeOfLParams . lambdaParams $ histOp op
-- | Return a copy of a 'HistOp' whose operator lambda has been given
-- fresh names with 'renameLambda'.  All other fields are unchanged.
renameHistop :: HistOp MCMem -> MulticoreGen (HistOp MCMem)
renameHistop histop = do
  lam' <- renameLambda $ histOp histop
  pure histop {histOp = lam'}
-- | Compile a histogram over a one-dimensional segment space.
--
-- No code runs when the iteration space is empty.  Otherwise the
-- generated code chooses at run time between two strategies:
--
-- * 'subHistogram', when @num_histos * hist_width <= n@ (where @n@ is the
--   size of the iteration space), and
-- * 'atomicHistogram' otherwise.
nonsegmentedHist ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.MCCode
nonsegmentedHist pat space histops kbody num_histos = do
  let ns_64 = map (pe64 . snd) $ unSegSpace space
      num_histos' = tvExp num_histos
      -- The heuristic only considers the width of the first operator.
      -- Pattern matching (instead of the partial 'head') keeps the code
      -- generator total when there are no operators at all.
      hist_width = case histops of
        [] -> 0
        op : _ -> histSize op
      use_subhistogram = sExt64 num_histos' * hist_width .<=. product ns_64

  -- The atomic branch works on separately renamed operators; presumably
  -- this avoids name clashes, because both branches are generated into
  -- the same code.
  histops' <- renameHistOpLambda histops

  collect $
    sUnless (product ns_64 .==. 0) $
      sIf
        use_subhistogram
        (subHistogram pat space histops num_histos kbody)
        (atomicHistogram pat space histops' kbody)
-- | Choose how a single histogram operator updates its destination
-- arrays, based on the atomics available in the host environment
-- ('hostAtomics').
--
-- The returned function takes the destination arrays and an index into
-- them.  When 'atomicUpdateLocking' can only offer a lock-based update,
-- a static, zero-initialised @int32@ array of @num_locks@ locks is
-- allocated.  Indices are mapped onto it with 'flattenIndex' modulo
-- @num_locks@.
onOpAtomic ::
  HistOp MCMem ->
  MulticoreGen ([VName] -> [Imp.TExp Int64] -> MulticoreGen ())
onOpAtomic op = do
  atomics <- hostAtomics <$> askEnv
  case atomicUpdateLocking atomics (histOp op) of
    AtomicPrim f -> pure f
    AtomicCAS f -> pure f
    AtomicLocking f -> do
      let num_locks :: Int
          num_locks = 100151
          -- NOTE(review): dims list histOpShape before histShape.  The
          -- caller in this module passes bucket indices first.  Since
          -- the mapping is pure, a given index always hits the same lock;
          -- the order only affects how indices spread over the locks.
          -- Confirm the intended order.
          dims = map pe64 $ shapeDims (histOpShape op <> histShape op)
          lock_index = pure . (`rem` fromIntegral num_locks) . flattenIndex dims
      locks <-
        sStaticArray "hist_locks" DefaultSpace int32 $
          Imp.ArrayZeros num_locks
      -- The three integer fields are presumably the unlocked value, the
      -- locked value and the value written on unlock; see 'Locking'.
      pure $ f $ Locking locks 0 1 0 lock_index
-- | Generate the atomic-update histogram strategy.
--
-- A single parallel loop (@atomic_seg_hist@) evaluates the kernel body
-- for every point of the iteration space and stores any map-out results.
-- It then applies each operator's contribution directly to the
-- destination arrays, using the update functions from 'onOpAtomic'.
-- Contributions whose bucket lies outside the destination shape are
-- skipped.
atomicHistogram ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen ()
atomicHistogram pat space histops kbody = do
  let (is, ns) = unzip $ unSegSpace space
      ns_64 = map pe64 ns
      -- NOTE(review): this counts one extra result per operator on top of
      -- its neutral elements, while 'pes_per_op' only consumes
      -- @sum (map (length . histDest) histops)@ of them.  Confirm which
      -- count the pattern split should use.
      num_red_res = length histops + sum (map (length . histNeutral) histops)
      (all_red_pes, map_pes) = splitAt num_red_res $ patElems pat
      pes_per_op = chunks (map (length . histDest) histops) all_red_pes

  atomicOps <- mapM onOpAtomic histops

  body <- collect $ do
    dPrim_ (segFlat space) int64
    sOp $ Imp.GetTaskId (segFlat space)

    generateChunkLoop "SegHist" Scalar $ \flat_idx -> do
      zipWithM_ dPrimV_ is $ unflattenIndex ns_64 flat_idx
      compileStms mempty (kernelBodyStms kbody) $ do
        let (red_res, map_res) =
              splitFromEnd (length map_pes) $ kernelBodyResult kbody
            red_res_split =
              splitHistResults histops $ map kernelResultSubExp red_res

        -- Map-out results do not depend on any operator, so they are
        -- written once per iteration rather than once per operator.
        sComment "save map-out results" $
          forM_ (zip map_pes map_res) $ \(pe, res) ->
            copyDWIMFix (patElemName pe) (map Imp.le64 is) (kernelResultSubExp res) []

        forM_ (zip4 histops red_res_split atomicOps pes_per_op) $
          \(HistOp dest_shape _ _ _ shape lam, (bucket, vs'), do_op, dest_res) -> do
            -- Skip the first @length vs'@ lambda parameters (presumably
            -- the accumulators); the remaining ones receive the values.
            let vs_params = drop (length vs') $ lambdaParams lam
                bucket' = map pe64 bucket
                dest_shape' = map pe64 $ shapeDims dest_shape
                bucket_in_bounds =
                  inBounds (Slice (map DimFix bucket')) dest_shape'

            sComment "perform updates" $
              sWhen bucket_in_bounds $ do
                let bucket_is = map Imp.le64 (init is) ++ bucket'
                dLParams $ lambdaParams lam
                sLoopNest shape $ \is' -> do
                  forM_ (zip vs_params vs') $ \(p, res) ->
                    copyDWIMFix (paramName p) [] res is'
                  do_op (map patElemName dest_res) (bucket_is ++ is')

  free_params <- freeParams body
  emit $ Imp.Op $ Imp.ParLoop "atomic_seg_hist" body free_params
-- | Apply a histogram operator to one bucket of the arrays @arrs@.
--
-- The current contents of @arrs@ at index @bucket@ are copied into the
-- accumulator parameters @uni_acc@.  The operator's lambda body is then
-- compiled, and its results are written back to @arrs@ at @bucket@.
-- Each write-back is wrapped in 'extractVectorLane' for lane @j@.
--
-- The operator's value parameters are not bound here; presumably the
-- caller has already bound them.
updateHisto ::
  HistOp MCMem ->
  [VName] ->
  [Imp.TExp Int64] ->
  Imp.TExp Int64 ->
  [Param LParamMem] ->
  MulticoreGen ()
updateHisto op arrs bucket j uni_acc = do
  let lam_body = lambdaBody $ histOp op
      writeArray arr val =
        extractVectorLane j . collect $ copyDWIMFix arr bucket val []
  sComment "Start of body" $ do
    -- Load the current bucket contents into the accumulator parameters.
    forM_ (zip uni_acc arrs) $ \(acc, arr) ->
      copyDWIMFix (paramName acc) [] (Var arr) bucket
    compileBody' [] lam_body
    -- Store the operator's results back into the bucket.
    zipWithM_ writeArray arrs $ map resSubExp $ bodyResult lam_body
subHistogram ::
Pat LetDecMem ->
SegSpace ->
[HistOp MCMem] ->
TV Int32 ->
KernelBody MCMem ->
MulticoreGen ()
subHistogram :: Pat LParamMem
-> SegSpace
-> [HistOp MCMem]
-> TV Int32
-> KernelBody MCMem
-> MulticoreGen ()
subHistogram Pat LParamMem
pat SegSpace
space [HistOp MCMem]
histops TV Int32
num_histos KernelBody MCMem
kbody = do
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"subHistogram segHist" forall a. Maybe a
Nothing
let ([VName]
is, [SubExp]
ns) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
ns_64 :: [TExp Int64]
ns_64 = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
ns
let pes :: [PatElem LParamMem]
pes = forall dec. Pat dec -> [PatElem dec]
patElems Pat LParamMem
pat
num_red_res :: Int
num_red_res = forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp MCMem]
histops forall a. Num a => a -> a -> a
+ forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum (forall a b. (a -> b) -> [a] -> [b]
map (forall (t :: * -> *) a. Foldable t => t a -> Int
length forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). HistOp rep -> [SubExp]
histNeutral) [HistOp MCMem]
histops)
map_pes :: [PatElem LParamMem]
map_pes = forall a. Int -> [a] -> [a]
drop Int
num_red_res [PatElem LParamMem]
pes
per_red_pes :: [[PatElem LParamMem]]
per_red_pes = forall {k} (rep :: k) a. [HistOp rep] -> [a] -> [[a]]
segHistOpChunks [HistOp MCMem]
histops forall a b. (a -> b) -> a -> b
$ forall dec. Pat dec -> [PatElem dec]
patElems Pat LParamMem
pat
[[VName]]
global_subhistograms <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp MCMem]
histops forall a b. (a -> b) -> a -> b
$ \HistOp MCMem
histop ->
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall {k} (rep :: k). HistOp rep -> [Type]
histType HistOp MCMem
histop) forall a b. (a -> b) -> a -> b
$ \Type
t -> do
let shape :: ShapeBase SubExp
shape = forall d. [d] -> ShapeBase d
Shape [forall {k} (t :: k). TV t -> SubExp
tvSize TV Int32
num_histos] forall a. Semigroup a => a -> a -> a
<> forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t
forall {k} (rep :: k) r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"subhistogram" (forall shape u. TypeBase shape u -> PrimType
elemType Type
t) ShapeBase SubExp
shape Space
DefaultSpace
let tid' :: TExp Int64
tid' = forall a. a -> TPrimExp Int64 a
Imp.le64 forall a b. (a -> b) -> a -> b
$ SegSpace -> VName
segFlat SegSpace
space
Code Multicore
body <- forall {k} (rep :: k) r op.
ImpM rep r op () -> ImpM rep r op (Code op)
collect forall a b. (a -> b) -> a -> b
$ do
forall {k} (rep :: k) r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ (SegSpace -> VName
segFlat SegSpace
space) PrimType
int64
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ VName -> Multicore
Imp.GetTaskId (SegSpace -> VName
segFlat SegSpace
space)
[[VName]]
local_subhistograms <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip [[PatElem LParamMem]]
per_red_pes [HistOp MCMem]
histops) forall a b. (a -> b) -> a -> b
$ \([PatElem LParamMem]
pes', HistOp MCMem
histop) -> do
[VName]
op_local_subhistograms <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall {k} (rep :: k). HistOp rep -> [Type]
histType HistOp MCMem
histop) forall a b. (a -> b) -> a -> b
$ \Type
t ->
forall {k} (rep :: k) r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"subhistogram" (forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t) Space
DefaultSpace
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem LParamMem]
pes' [VName]
op_local_subhistograms (forall {k} (rep :: k). HistOp rep -> [SubExp]
histNeutral HistOp MCMem
histop)) forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, VName
hist, SubExp
ne) ->
forall {k} (rep :: k) r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TExp Int64
tid' forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0)
(forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
hist [] (VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe) [])
( forall {k} (rep :: k) r op.
ShapeBase SubExp
-> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (forall {k} (rep :: k). HistOp rep -> ShapeBase SubExp
histShape HistOp MCMem
histop) forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
shape_is ->
forall {k} (rep :: k) r op.
ShapeBase SubExp
-> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (forall {k} (rep :: k). HistOp rep -> ShapeBase SubExp
histOpShape HistOp MCMem
histop) forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is ->
forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
hist ([TExp Int64]
shape_is forall a. Semigroup a => a -> a -> a
<> [TExp Int64]
vec_is) SubExp
ne []
)
forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName]
op_local_subhistograms
MulticoreGen () -> MulticoreGen ()
inISPC forall a b. (a -> b) -> a -> b
$
String
-> ChunkLoopVectorization
-> (TExp Int64 -> MulticoreGen ())
-> MulticoreGen ()
generateChunkLoop String
"SegRed" ChunkLoopVectorization
Vectorized forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
is forall a b. (a -> b) -> a -> b
$ forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
ns_64 TExp Int64
i
forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). KernelBody rep -> Stms rep
kernelBodyStms KernelBody MCMem
kbody) forall a b. (a -> b) -> a -> b
$ do
let ([SubExp]
red_res, [SubExp]
map_res) =
forall a. Int -> [a] -> ([a], [a])
splitFromEnd (forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem LParamMem]
map_pes) forall a b. (a -> b) -> a -> b
$
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k). KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody MCMem
kbody
forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"save map-out results" forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
map_pes [SubExp]
map_res) forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, SubExp
res) ->
forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe) (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
is) SubExp
res []
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [HistOp MCMem]
histops [[VName]]
local_subhistograms (forall {k} (rep :: k).
[HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults [HistOp MCMem]
histops [SubExp]
red_res)) forall a b. (a -> b) -> a -> b
$
\( histop :: HistOp MCMem
histop@(HistOp ShapeBase SubExp
dest_shape SubExp
_ [VName]
_ [SubExp]
_ ShapeBase SubExp
shape Lambda MCMem
_),
[VName]
histop_subhistograms,
([SubExp]
bucket, [SubExp]
vs')
) -> do
HistOp MCMem
histop' <- HistOp MCMem -> MulticoreGen (HistOp MCMem)
renameHistop HistOp MCMem
histop
let bucket' :: [TExp Int64]
bucket' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
bucket
dest_shape' :: [TExp Int64]
dest_shape' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 forall a b. (a -> b) -> a -> b
$ forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
dest_shape
acc_params' :: [Param LParamMem]
acc_params' = (forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). HistOp rep -> Lambda rep
histOp) HistOp MCMem
histop'
vs_params' :: [Param LParamMem]
vs_params' = forall a. Int -> [a] -> [a]
takeLast (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). HistOp rep -> Lambda rep
histOp HistOp MCMem
histop'
(TExp Int64 -> MulticoreGen ()) -> MulticoreGen ()
generateUniformizeLoop forall a b. (a -> b) -> a -> b
$ \TExp Int64
j ->
forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"perform updates" forall a b. (a -> b) -> a -> b
$ do
[TV Int64]
extract_buckets <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall {k1} {k2} (rep :: k1) r op (t :: k2).
String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"extract_bucket" forall b c a. (b -> c) -> (a -> b) -> a -> c
. (forall v. PrimExp v -> PrimType
primExpType forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped)) [TExp Int64]
bucket'
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [TV Int64]
extract_buckets [TExp Int64]
bucket') forall a b. (a -> b) -> a -> b
$ \(TV Int64
x, TExp Int64
y) ->
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. a -> Code a
Imp.Op forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Exp -> Multicore
Imp.ExtractLane (forall {k} (t :: k). TV t -> VName
tvVar TV Int64
x) (forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
y) (forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
j)
let bucket'' :: [TExp Int64]
bucket'' = forall a b. (a -> b) -> [a] -> [b]
map forall {k} (t :: k). TV t -> TExp t
tvExp [TV Int64]
extract_buckets
bucket_in_bounds :: TPrimExp Bool VName
bucket_in_bounds =
Slice (TExp Int64) -> [TExp Int64] -> TPrimExp Bool VName
inBounds (forall d. [DimIndex d] -> Slice d
Slice (forall a b. (a -> b) -> [a] -> [b]
map forall d. d -> DimIndex d
DimFix [TExp Int64]
bucket'')) [TExp Int64]
dest_shape'
forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen TPrimExp Bool VName
bucket_in_bounds forall a b. (a -> b) -> a -> b
$ do
HistOp MCMem -> MulticoreGen ()
genHistOpParams HistOp MCMem
histop'
forall {k} (rep :: k) r op.
ShapeBase SubExp
-> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest ShapeBase SubExp
shape forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is' -> do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
vs_params' [SubExp]
vs') forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
res) ->
forall {f :: * -> *} {shape} {u}.
Applicative f =>
TypeBase shape u -> (PrimType -> f ()) -> f ()
ifPrimType (forall dec. Typed dec => Param dec -> Type
paramType Param LParamMem
p) forall a b. (a -> b) -> a -> b
$ \PrimType
pt -> do
TV Any
tmp <- forall {k1} {k2} (rep :: k1) r op (t :: k2).
String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"tmp" PrimType
pt
forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall {k} (t :: k). TV t -> VName
tvVar TV Any
tmp) [] SubExp
res [TExp Int64]
is'
TExp Int64 -> MulticoreGen (Code Multicore) -> MulticoreGen ()
extractVectorLane TExp Int64
j forall a b. (a -> b) -> a -> b
$
forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$
forall a. VName -> Exp -> Code a
Imp.SetScalar (forall dec. Param dec -> VName
paramName Param LParamMem
p) (forall v. v -> PrimType -> PrimExp v
Imp.LeafExp (forall {k} (t :: k). TV t -> VName
tvVar TV Any
tmp) PrimType
pt)
HistOp MCMem
-> [VName]
-> [TExp Int64]
-> TExp Int64
-> [Param LParamMem]
-> MulticoreGen ()
updateHisto HistOp MCMem
histop' [VName]
histop_subhistograms ([TExp Int64]
bucket'' forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is') TExp Int64
j [Param LParamMem]
acc_params'
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip (forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
global_subhistograms) (forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
local_subhistograms)) forall a b. (a -> b) -> a -> b
$
\(VName
global, VName
local) -> forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
global [TExp Int64
tid'] (VName -> SubExp
Var VName
local) []
[Param]
free_params <- forall a. FreeIn a => a -> MulticoreGen [Param]
freeParams Code Multicore
body
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. a -> Code a
Imp.Op forall a b. (a -> b) -> a -> b
$ String -> Code Multicore -> [Param] -> Multicore
Imp.ParLoop String
"seghist_stage_1" Code Multicore
body [Param]
free_params
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElem LParamMem]]
per_red_pes [[VName]]
global_subhistograms [HistOp MCMem]
histops) forall a b. (a -> b) -> a -> b
$ \([PatElem LParamMem]
red_pes, [VName]
hists, HistOp MCMem
op) -> do
[VName]
bucket_ids <-
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (forall a. ArrayShape a => a -> Int
shapeRank (forall {k} (rep :: k). HistOp rep -> ShapeBase SubExp
histShape HistOp MCMem
op)) (forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"bucket_id")
VName
subhistogram_id <- forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"subhistogram_id"
let segred_space :: SegSpace
segred_space =
VName -> [(VName, SubExp)] -> SegSpace
SegSpace (SegSpace -> VName
segFlat SegSpace
space) forall a b. (a -> b) -> a -> b
$
[(VName, SubExp)]
segment_dims
forall a. [a] -> [a] -> [a]
++ forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
bucket_ids (forall d. ShapeBase d -> [d]
shapeDims (forall {k} (rep :: k). HistOp rep -> ShapeBase SubExp
histShape HistOp MCMem
op))
forall a. [a] -> [a] -> [a]
++ [(VName
subhistogram_id, forall {k} (t :: k). TV t -> SubExp
tvSize TV Int32
num_histos)]
segred_op :: SegBinOp MCMem
segred_op = forall {k} (rep :: k).
Commutativity
-> Lambda rep -> [SubExp] -> ShapeBase SubExp -> SegBinOp rep
SegBinOp Commutativity
Noncommutative (forall {k} (rep :: k). HistOp rep -> Lambda rep
histOp HistOp MCMem
op) (forall {k} (rep :: k). HistOp rep -> [SubExp]
histNeutral HistOp MCMem
op) (forall {k} (rep :: k). HistOp rep -> ShapeBase SubExp
histOpShape HistOp MCMem
op)
Code Multicore
red_code <- forall {k} (rep :: k) r op.
ImpM rep r op () -> ImpM rep r op (Code op)
collect forall a b. (a -> b) -> a -> b
$ do
TV Int32
nsubtasks <- forall {k1} {k2} (rep :: k1) r op (t :: k2).
String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"nsubtasks" PrimType
int32
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ VName -> Multicore
Imp.GetNumTasks forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> VName
tvVar TV Int32
nsubtasks
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< Pat LParamMem
-> SegSpace
-> [SegBinOp MCMem]
-> TV Int32
-> DoSegBody
-> MulticoreGen (Code Multicore)
compileSegRed' (forall dec. [PatElem dec] -> Pat dec
Pat [PatElem LParamMem]
red_pes) SegSpace
segred_space [SegBinOp MCMem
segred_op] TV Int32
nsubtasks forall a b. (a -> b) -> a -> b
$ \[[(SubExp, [TExp Int64])]] -> MulticoreGen ()
red_cont ->
[[(SubExp, [TExp Int64])]] -> MulticoreGen ()
red_cont forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k) a. [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks [SegBinOp MCMem
segred_op] forall a b. (a -> b) -> a -> b
$
forall a b c. (a -> b -> c) -> b -> a -> c
flip forall a b. (a -> b) -> [a] -> [b]
map [VName]
hists forall a b. (a -> b) -> a -> b
$ \VName
subhisto ->
( VName -> SubExp
Var VName
subhisto,
forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 forall a b. (a -> b) -> a -> b
$
forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst [(VName, SubExp)]
segment_dims forall a. [a] -> [a] -> [a]
++ [VName
subhistogram_id] forall a. [a] -> [a] -> [a]
++ [VName]
bucket_ids
)
let ns_red :: [TExp Int64]
ns_red = forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TExp Int64
pe64 forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> b
snd) forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
segred_space
iterations :: TExp Int64
iterations = forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product forall a b. (a -> b) -> a -> b
$ forall a. [a] -> [a]
init [TExp Int64]
ns_red
scheduler_info :: SchedulerInfo
scheduler_info = Exp -> Scheduling -> SchedulerInfo
Imp.SchedulerInfo (forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
iterations) Scheduling
Imp.Static
red_task :: ParallelTask
red_task = Code Multicore -> ParallelTask
Imp.ParallelTask Code Multicore
red_code
[Param]
free_params_red <- forall a. FreeIn a => a -> MulticoreGen [Param]
freeParams Code Multicore
red_code
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. a -> Code a
Imp.Op forall a b. (a -> b) -> a -> b
$ String
-> [Param]
-> ParallelTask
-> Maybe ParallelTask
-> [Param]
-> SchedulerInfo
-> Multicore
Imp.SegOp String
"seghist_red" [Param]
free_params_red ParallelTask
red_task forall a. Maybe a
Nothing forall a. Monoid a => a
mempty SchedulerInfo
scheduler_info
where
segment_dims :: [(VName, SubExp)]
segment_dims = forall a. [a] -> [a]
init forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
ifPrimType :: TypeBase shape u -> (PrimType -> f ()) -> f ()
ifPrimType (Prim PrimType
pt) PrimType -> f ()
f = PrimType -> f ()
f PrimType
pt
ifPrimType TypeBase shape u
_ PrimType -> f ()
_ = forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
segmentedHist ::
Pat LetDecMem ->
SegSpace ->
[HistOp MCMem] ->
KernelBody MCMem ->
MulticoreGen Imp.MCCode
segmentedHist :: Pat LParamMem
-> SegSpace
-> [HistOp MCMem]
-> KernelBody MCMem
-> MulticoreGen (Code Multicore)
segmentedHist Pat LParamMem
pat SegSpace
space [HistOp MCMem]
histops KernelBody MCMem
kbody = do
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Segmented segHist" forall a. Maybe a
Nothing
forall {k} (rep :: k) r op.
ImpM rep r op () -> ImpM rep r op (Code op)
collect forall a b. (a -> b) -> a -> b
$ do
Code Multicore
body <- Pat LParamMem
-> SegSpace
-> [HistOp MCMem]
-> KernelBody MCMem
-> MulticoreGen (Code Multicore)
compileSegHistBody Pat LParamMem
pat SegSpace
space [HistOp MCMem]
histops KernelBody MCMem
kbody
[Param]
free_params <- forall a. FreeIn a => a -> MulticoreGen [Param]
freeParams Code Multicore
body
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. a -> Code a
Imp.Op forall a b. (a -> b) -> a -> b
$ String -> Code Multicore -> [Param] -> Multicore
Imp.ParLoop String
"segmented_hist" Code Multicore
body [Param]
free_params
compileSegHistBody ::
Pat LetDecMem ->
SegSpace ->
[HistOp MCMem] ->
KernelBody MCMem ->
MulticoreGen Imp.MCCode
compileSegHistBody :: Pat LParamMem
-> SegSpace
-> [HistOp MCMem]
-> KernelBody MCMem
-> MulticoreGen (Code Multicore)
compileSegHistBody Pat LParamMem
pat SegSpace
space [HistOp MCMem]
histops KernelBody MCMem
kbody = forall {k} (rep :: k) r op.
ImpM rep r op () -> ImpM rep r op (Code op)
collect forall a b. (a -> b) -> a -> b
$ do
let ([VName]
is, [SubExp]
ns) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
ns_64 :: [TExp Int64]
ns_64 = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
ns
let num_red_res :: Int
num_red_res = forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp MCMem]
histops forall a. Num a => a -> a -> a
+ forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum (forall a b. (a -> b) -> [a] -> [b]
map (forall (t :: * -> *) a. Foldable t => t a -> Int
length forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). HistOp rep -> [SubExp]
histNeutral) [HistOp MCMem]
histops)
map_pes :: [PatElem LParamMem]
map_pes = forall a. Int -> [a] -> [a]
drop Int
num_red_res forall a b. (a -> b) -> a -> b
$ forall dec. Pat dec -> [PatElem dec]
patElems Pat LParamMem
pat
per_red_pes :: [[PatElem LParamMem]]
per_red_pes = forall {k} (rep :: k) a. [HistOp rep] -> [a] -> [[a]]
segHistOpChunks [HistOp MCMem]
histops forall a b. (a -> b) -> a -> b
$ forall dec. Pat dec -> [PatElem dec]
patElems Pat LParamMem
pat
forall {k} (rep :: k) r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ (SegSpace -> VName
segFlat SegSpace
space) PrimType
int64
forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ VName -> Multicore
Imp.GetTaskId (SegSpace -> VName
segFlat SegSpace
space)
String
-> ChunkLoopVectorization
-> (TExp Int64 -> MulticoreGen ())
-> MulticoreGen ()
generateChunkLoop String
"SegHist" ChunkLoopVectorization
Scalar forall a b. (a -> b) -> a -> b
$ \TExp Int64
idx -> do
let inner_bound :: TExp Int64
inner_bound = forall a. [a] -> a
last [TExp Int64]
ns_64
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
inner_bound forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ (forall a. [a] -> [a]
init [VName]
is) forall a b. (a -> b) -> a -> b
$ forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex (forall a. [a] -> [a]
init [TExp Int64]
ns_64) TExp Int64
idx
forall {k1} {k2} (t :: k1) (rep :: k2) r op.
VName -> TExp t -> ImpM rep r op ()
dPrimV_ (forall a. [a] -> a
last [VName]
is) TExp Int64
i
forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). KernelBody rep -> Stms rep
kernelBodyStms KernelBody MCMem
kbody) forall a b. (a -> b) -> a -> b
$ do
let ([SubExp]
red_res, [SubExp]
map_res) =
forall a. Int -> [a] -> ([a], [a])
splitFromEnd (forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem LParamMem]
map_pes) forall a b. (a -> b) -> a -> b
$
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k). KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody MCMem
kbody
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElem LParamMem]]
per_red_pes [HistOp MCMem]
histops (forall {k} (rep :: k).
[HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults [HistOp MCMem]
histops [SubExp]
red_res)) forall a b. (a -> b) -> a -> b
$
\([PatElem LParamMem]
red_pes, HistOp ShapeBase SubExp
dest_shape SubExp
_ [VName]
_ [SubExp]
_ ShapeBase SubExp
shape Lambda MCMem
lam, ([SubExp]
bucket, [SubExp]
vs')) -> do
let ([Param LParamMem]
is_params, [Param LParamMem]
vs_params) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda MCMem
lam
bucket' :: [TExp Int64]
bucket' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
bucket
dest_shape' :: [TExp Int64]
dest_shape' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 forall a b. (a -> b) -> a -> b
$ forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
dest_shape
bucket_in_bounds :: TPrimExp Bool VName
bucket_in_bounds = Slice (TExp Int64) -> [TExp Int64] -> TPrimExp Bool VName
inBounds (forall d. [DimIndex d] -> Slice d
Slice (forall a b. (a -> b) -> [a] -> [b]
map forall d. d -> DimIndex d
DimFix [TExp Int64]
bucket')) [TExp Int64]
dest_shape'
forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"save map-out results" forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
map_pes [SubExp]
map_res) forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, SubExp
res) ->
forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe) (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
is) SubExp
res []
forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"perform updates" forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k) r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen TPrimExp Bool VName
bucket_in_bounds forall a b. (a -> b) -> a -> b
$ do
forall {k} (rep :: k) inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda MCMem
lam
forall {k} (rep :: k) r op.
ShapeBase SubExp
-> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest ShapeBase SubExp
shape forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is -> do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
red_pes [Param LParamMem]
is_params) forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, Param LParamMem
p) ->
forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(forall dec. Param dec -> VName
paramName Param LParamMem
p)
[]
(VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
(forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 (forall a. [a] -> [a]
init [VName]
is) forall a. [a] -> [a] -> [a]
++ [TExp Int64]
bucket' forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
vs_params [SubExp]
vs') forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
v) ->
forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
v [TExp Int64]
vec_is
forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda MCMem
lam) forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
red_pes forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda MCMem
lam) forall a b. (a -> b) -> a -> b
$
\(PatElem LParamMem
pe, SubExp
se) ->
forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
(forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 (forall a. [a] -> [a]
init [VName]
is) forall a. [a] -> [a] -> [a]
++ [TExp Int64]
bucket' forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)
SubExp
se
[]