module Futhark.CodeGen.ImpGen.Multicore.SegHist
  ( compileSegHist,
  )
where

import Control.Monad
import Data.List (zip4, zip5)
import qualified Futhark.CodeGen.ImpCode.Multicore as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.CodeGen.ImpGen.Multicore.SegRed (compileSegRed')
import Futhark.IR.MCMem
import Futhark.MonadFreshNames
import Futhark.Util (chunks, splitFromEnd, takeLast)
import Futhark.Util.IntegralExp (rem)
import Prelude hiding (quot, rem)

-- | Compile a 'SegHist'.  A one-dimensional segment space is handled
-- by 'nonsegmentedHist', which is passed @nsubtasks@ as its number of
-- histograms; every other segment space is handled by 'segmentedHist'.
compileSegHist ::
  Pattern MCMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.Code
compileSegHist pat space histops kbody nsubtasks =
  case unSegSpace space of
    [_] -> nonsegmentedHist pat space histops kbody nsubtasks
    _ -> segmentedHist pat space histops kbody

-- | Split a flat list into consecutive chunks, one per 'HistOp'.  The
-- length of each chunk is the number of neutral elements
-- ('histNeutral') of the corresponding operator, i.e. the number of
-- values that operator works on.
segHistOpChunks :: [HistOp lore] -> [a] -> [[a]]
segHistOpChunks = chunks . map (length . histNeutral)

-- | Compile a histogram over a one-dimensional segment space.
--
-- If @num_histos@ copies of the histogram take no more elements than
-- the input has, the sub-histogram strategy ('subHistogram') is used;
-- otherwise updates are applied directly to the destination with
-- 'atomicHistogram'.  The choice is made at run time via 'sIf', and
-- nothing is emitted-to-run when the input is empty.
nonsegmentedHist ::
  Pattern MCMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.Code
nonsegmentedHist pat space histops kbody num_histos = do
  -- The atomic branch gets operators with renamed lambdas; presumably
  -- so that the two branches of the 'sIf' below do not declare the
  -- same names twice.
  histops' <- renameHistOpLambda histops

  -- Only do something if there is actually input.
  collect $
    sUnless (total_input .==. 0) $ do
      flat_idx <- dPrim "iter" int64
      sIf
        use_subhistogram
        (subHistogram pat flat_idx space histops num_histos kbody)
        (atomicHistogram pat flat_idx space histops' kbody)
  where
    -- Total number of input elements (product of the segment dimensions).
    total_input = product $ map (toInt64Exp . snd) $ unSegSpace space
    -- Width of the first operator's histogram.
    hist_width = toInt64Exp $ histWidth $ head histops
    use_subhistogram =
      sExt64 (tvExp num_histos) * hist_width .<=. total_input

-- | Choose how a single 'HistOp' is applied in the atomic histogram
-- approach.  'atomicUpdateLocking', given the atomic operations the host
-- supports, yields one of three strategies for the operator's lambda:
--
-- 1. 'AtomicPrim': the update is a directly supported atomic operation
--    (e.g. on integral scalars).
-- 2. 'AtomicCAS': the value occupies a single memory location (e.g. a
--    float) and is updated with a compare-and-swap, treating the value
--    as an integral scalar.
-- 3. 'AtomicLocking': anything else; updates are guarded by locks.
--
-- Strategies 1 and 2 currently only handle 32- and 64-bit types,
-- although GCC also supports 8-, 16- and 128-bit atomics.
--
-- For the locking strategy a static, zero-initialised array of
-- @num_locks@ 32-bit locks is allocated, and an update at a given
-- index uses the lock at (flattened index) modulo @num_locks@.
onOpAtomic :: HistOp MCMem -> MulticoreGen ([VName] -> [Imp.TExp Int64] -> MulticoreGen ())
onOpAtomic op = do
  host_atomics <- hostAtomics <$> askEnv
  case atomicUpdateLocking host_atomics (histOp op) of
    AtomicPrim update -> pure update
    AtomicCAS update -> pure update
    AtomicLocking mkUpdate -> mkUpdate <$> allocLocking
  where
    -- Same lock count as the GPU backend uses.
    num_locks :: Int
    num_locks = 100151

    -- Full index space of the destination: element shape, then width.
    dims = map toInt64Exp $ shapeDims (histShape op) ++ [histWidth op]

    lockIndex :: [Imp.TExp Int64] -> [Imp.TExp Int64]
    lockIndex = pure . (`rem` fromIntegral num_locks) . flattenIndex dims

    allocLocking = do
      lock_arr <-
        sStaticArray "hist_locks" DefaultSpace int32 $
          Imp.ArrayZeros num_locks
      pure $ Locking lock_arr 0 1 0 lockIndex

-- | Atomic histogram approach: a single parallel loop ('Imp.ParLoop')
-- over the flat iteration space.  Each iteration evaluates the kernel
-- body and applies every operator's update directly to the pattern's
-- destination arrays, using the strategy chosen by 'onOpAtomic'.
--
-- The kernel body result is laid out as one bucket index per operator,
-- then the values for all operators, then any map-out results.  Updates
-- whose bucket lies outside @[0, dest_w)@ are skipped.
atomicHistogram ::
  Pattern MCMem ->
  TV Int64 ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen ()
atomicHistogram pat flat_idx space histops kbody = do
  let (is, ns) = unzip $ unSegSpace space
      ns_64 = map toInt64Exp ns
      -- Number of destination arrays per operator; used to chunk both
      -- the values and the pattern elements per operator.
      dests_per_op = map (length . histDest) histops
      -- NOTE(review): this counts one bucket per operator plus all the
      -- values (i.e. kernel-body results), not pattern elements; kept as
      -- in the original split - confirm against the SegHist pattern layout.
      num_red_res = length histops + sum (map (length . histNeutral) histops)
      (all_red_pes, map_pes) = splitAt num_red_res $ patternValueElements pat
      pes_per_op = chunks dests_per_op all_red_pes

  atomicOps <- mapM onOpAtomic histops

  body <- collect $ do
    zipWithM_ dPrimV_ is $ unflattenIndex ns_64 $ tvExp flat_idx
    compileStms mempty (kernelBodyStms kbody) $ do
      let (red_res, map_res) =
            splitFromEnd (length map_pes) $ kernelBodyResult kbody
          (buckets, vs) = splitAt (length histops) red_res

      -- The map-out results do not depend on any operator, so write
      -- them once per iteration rather than once per operator.
      sComment "save map-out results" $
        forM_ (zip map_pes map_res) $ \(pe, res) ->
          copyDWIMFix (patElemName pe) (map Imp.vi64 is) (kernelResultSubExp res) []

      forM_ (zip5 histops (chunks dests_per_op vs) buckets atomicOps pes_per_op) $
        \(HistOp dest_w _ _ _ shape lam, vs', bucket, do_op, dest_res) -> do
          -- The leading lambda parameters are the accumulators; the
          -- remaining ones receive this iteration's values.
          let vs_params = drop (length vs') $ lambdaParams lam
              dest_w' = toInt64Exp dest_w
              bucket' = toInt64Exp $ kernelResultSubExp bucket
              bucket_in_bounds = bucket' .<. dest_w' .&&. 0 .<=. bucket'

          sComment "perform updates" $
            sWhen bucket_in_bounds $ do
              let bucket_is = map Imp.vi64 (init is) ++ [bucket']
              dLParams $ lambdaParams lam
              sLoopNest shape $ \is' -> do
                forM_ (zip vs_params vs') $ \(p, res) ->
                  copyDWIMFix (paramName p) [] (kernelResultSubExp res) is'
                do_op (map patElemName dest_res) (bucket_is ++ is')

  free_params <- freeParams body (segFlat space : [tvVar flat_idx])
  emit $
    Imp.Op $
      Imp.ParLoop "atomic_seg_hist" (tvVar flat_idx) mempty body mempty free_params $
        segFlat space

-- | Apply the operator of a 'HistOp' to the histogram arrays @arrs@ at
-- index @bucket@, using plain (non-atomic) reads and writes.
--
-- The first @length arrs@ lambda parameters are declared and bound to
-- the current contents of @arrs@ at @bucket@.  The lambda body is then
-- compiled and its results are written back to @arrs@ at @bucket@.  The
-- remaining lambda parameters are not bound here; presumably the caller
-- binds them beforehand - confirm.
updateHisto :: HistOp MCMem -> [VName] -> [Imp.TExp Int64] -> MulticoreGen ()
updateHisto op arrs bucket =
  sComment "Start of body" $ do
    dLParams acc_params
    -- Load the current histogram values into the accumulator parameters.
    forM_ (zip acc_params arrs) $ \(acc_p, arr) ->
      copyDWIMFix (paramName acc_p) [] (Var arr) bucket
    compileBody' [] lam_body
    -- Store the combined values back into the histogram.
    forM_ (zip arrs (bodyResult lam_body)) $ \(arr, res) ->
      copyDWIMFix arr bucket res []
  where
    lam = histOp op
    lam_body = lambdaBody lam
    acc_params = take (length arrs) $ lambdaParams lam

-- Generates num_histos sub-histograms, each the size of the
-- destination histogram.
-- Then, for each chunk of the input, a sub-histogram is computed, and
-- finally the sub-histograms are combined through a segmented reduction
-- across the histogram indices.
-- This is expected to be fast if len(histDest) is small.
subHistogram ::
  Pattern MCMem ->
  TV Int64 ->
  SegSpace ->
  [HistOp MCMem] ->
  TV Int32 ->
  KernelBody MCMem ->
  MulticoreGen ()
subHistogram :: Pattern MCMem
-> TV Int64
-> SegSpace
-> [HistOp MCMem]
-> TV Int32
-> KernelBody MCMem
-> ImpM MCMem HostEnv Multicore ()
subHistogram Pattern MCMem
pat TV Int64
flat_idx SegSpace
space [HistOp MCMem]
histops TV Int32
num_histos KernelBody MCMem
kbody = do
  Code -> ImpM MCMem HostEnv Multicore ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code -> ImpM MCMem HostEnv Multicore ())
-> Code -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"subHistogram segHist" Maybe Exp
forall a. Maybe a
Nothing

  let ([VName]
is, [SubExp]
ns) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      ns_64 :: [TExp Int64]
ns_64 = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
ns

  let pes :: [PatElemT LParamMem]
pes = PatternT LParamMem -> [PatElemT LParamMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern MCMem
PatternT LParamMem
pat
      num_red_res :: Int
num_red_res = [HistOp MCMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp MCMem]
histops Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((HistOp MCMem -> Int) -> [HistOp MCMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (HistOp MCMem -> [SubExp]) -> HistOp MCMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp MCMem -> [SubExp]
forall lore. HistOp lore -> [SubExp]
histNeutral) [HistOp MCMem]
histops)
      map_pes :: [PatElemT LParamMem]
map_pes = Int -> [PatElemT LParamMem] -> [PatElemT LParamMem]
forall a. Int -> [a] -> [a]
drop Int
num_red_res [PatElemT LParamMem]
pes
      per_red_pes :: [[PatElemT LParamMem]]
per_red_pes = [HistOp MCMem] -> [PatElemT LParamMem] -> [[PatElemT LParamMem]]
forall lore a. [HistOp lore] -> [a] -> [[a]]
segHistOpChunks [HistOp MCMem]
histops ([PatElemT LParamMem] -> [[PatElemT LParamMem]])
-> [PatElemT LParamMem] -> [[PatElemT LParamMem]]
forall a b. (a -> b) -> a -> b
$ PatternT LParamMem -> [PatElemT LParamMem]
forall dec. PatternT dec -> [PatElemT dec]
patternValueElements Pattern MCMem
PatternT LParamMem
pat

  -- Allocate array of subhistograms in the calling thread.  Each
  -- tasks will work in its own private allocations (to avoid false
  -- sharing), but this is where they will ultimately copy their
  -- results.
  [[VName]]
global_subhistograms <- [HistOp MCMem]
-> (HistOp MCMem -> ImpM MCMem HostEnv Multicore [VName])
-> ImpM MCMem HostEnv Multicore [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp MCMem]
histops ((HistOp MCMem -> ImpM MCMem HostEnv Multicore [VName])
 -> ImpM MCMem HostEnv Multicore [[VName]])
-> (HistOp MCMem -> ImpM MCMem HostEnv Multicore [VName])
-> ImpM MCMem HostEnv Multicore [[VName]]
forall a b. (a -> b) -> a -> b
$ \HistOp MCMem
histop ->
    [Type]
-> (Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (HistOp MCMem -> [Type]
forall lore. HistOp lore -> [Type]
histType HistOp MCMem
histop) ((Type -> ImpM MCMem HostEnv Multicore VName)
 -> ImpM MCMem HostEnv Multicore [VName])
-> (Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName]
forall a b. (a -> b) -> a -> b
$ \Type
t -> do
      let shape :: Shape
shape = [SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [TV Int32 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int32
num_histos] Shape -> Shape -> Shape
forall a. Semigroup a => a -> a -> a
<> Type -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t
      String
-> PrimType -> Shape -> Space -> ImpM MCMem HostEnv Multicore VName
forall lore r op.
String -> PrimType -> Shape -> Space -> ImpM lore r op VName
sAllocArray String
"subhistogram" (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) Shape
shape Space
DefaultSpace

  let tid' :: TExp Int64
tid' = VName -> TExp Int64
Imp.vi64 (VName -> TExp Int64) -> VName -> TExp Int64
forall a b. (a -> b) -> a -> b
$ SegSpace -> VName
segFlat SegSpace
space
      flat_idx' :: TExp Int64
flat_idx' = TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
flat_idx

  ([[VName]]
local_subhistograms, Code
prebody) <- ImpM MCMem HostEnv Multicore [[VName]]
-> ImpM MCMem HostEnv Multicore ([[VName]], Code)
forall lore r op a. ImpM lore r op a -> ImpM lore r op (a, Code op)
collect' (ImpM MCMem HostEnv Multicore [[VName]]
 -> ImpM MCMem HostEnv Multicore ([[VName]], Code))
-> ImpM MCMem HostEnv Multicore [[VName]]
-> ImpM MCMem HostEnv Multicore ([[VName]], Code)
forall a b. (a -> b) -> a -> b
$ do
    (VName -> TExp Int64 -> ImpM MCMem HostEnv Multicore ())
-> [VName] -> [TExp Int64] -> ImpM MCMem HostEnv Multicore ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> ImpM MCMem HostEnv Multicore ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ [VName]
is ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> [TExp Int64] -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
ns_64 (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
flat_idx'

    [([PatElemT LParamMem], HistOp MCMem)]
-> (([PatElemT LParamMem], HistOp MCMem)
    -> ImpM MCMem HostEnv Multicore [VName])
-> ImpM MCMem HostEnv Multicore [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([[PatElemT LParamMem]]
-> [HistOp MCMem] -> [([PatElemT LParamMem], HistOp MCMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [[PatElemT LParamMem]]
per_red_pes [HistOp MCMem]
histops) ((([PatElemT LParamMem], HistOp MCMem)
  -> ImpM MCMem HostEnv Multicore [VName])
 -> ImpM MCMem HostEnv Multicore [[VName]])
-> (([PatElemT LParamMem], HistOp MCMem)
    -> ImpM MCMem HostEnv Multicore [VName])
-> ImpM MCMem HostEnv Multicore [[VName]]
forall a b. (a -> b) -> a -> b
$ \([PatElemT LParamMem]
pes', HistOp MCMem
histop) -> do
      [VName]
op_local_subhistograms <- [Type]
-> (Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (HistOp MCMem -> [Type]
forall lore. HistOp lore -> [Type]
histType HistOp MCMem
histop) ((Type -> ImpM MCMem HostEnv Multicore VName)
 -> ImpM MCMem HostEnv Multicore [VName])
-> (Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName]
forall a b. (a -> b) -> a -> b
$ \Type
t ->
        String
-> PrimType -> Shape -> Space -> ImpM MCMem HostEnv Multicore VName
forall lore r op.
String -> PrimType -> Shape -> Space -> ImpM lore r op VName
sAllocArray String
"subhistogram" (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (Type -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t) Space
DefaultSpace

      [(PatElemT LParamMem, VName, SubExp)]
-> ((PatElemT LParamMem, VName, SubExp)
    -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LParamMem]
-> [VName] -> [SubExp] -> [(PatElemT LParamMem, VName, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElemT LParamMem]
pes' [VName]
op_local_subhistograms (HistOp MCMem -> [SubExp]
forall lore. HistOp lore -> [SubExp]
histNeutral HistOp MCMem
histop)) (((PatElemT LParamMem, VName, SubExp)
  -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> ((PatElemT LParamMem, VName, SubExp)
    -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LParamMem
pe, VName
hist, SubExp
ne) ->
        -- First thread initializes histogram with dest vals. Others
        -- initialize with neutral element
        TPrimExp Bool ExpLeaf
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
TPrimExp Bool ExpLeaf
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf
          (TExp Int64
tid' TExp Int64 -> TExp Int64 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0)
          (VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
hist [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe) [])
          ( String
-> TExp Int64
-> (TExp Int64 -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" (SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64) -> SubExp -> TExp Int64
forall a b. (a -> b) -> a -> b
$ HistOp MCMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth HistOp MCMem
histop) ((TExp Int64 -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> (TExp Int64 -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i ->
              Shape
-> ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (HistOp MCMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp MCMem
histop) (([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is ->
                VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
hist (TExp Int64
i TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
vec_is) SubExp
ne []
          )

      [VName] -> ImpM MCMem HostEnv Multicore [VName]
forall (m :: * -> *) a. Monad m => a -> m a
return [VName]
op_local_subhistograms

  -- Generate loop body of parallel function
  Code
body <- ImpM MCMem HostEnv Multicore () -> MulticoreGen Code
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect (ImpM MCMem HostEnv Multicore () -> MulticoreGen Code)
-> ImpM MCMem HostEnv Multicore () -> MulticoreGen Code
forall a b. (a -> b) -> a -> b
$ do
    (VName -> TExp Int64 -> ImpM MCMem HostEnv Multicore ())
-> [VName] -> [TExp Int64] -> ImpM MCMem HostEnv Multicore ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> ImpM MCMem HostEnv Multicore ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ [VName]
is ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> [TExp Int64] -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
ns_64 (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
flat_idx'
    Names
-> Stms MCMem
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody MCMem -> Stms MCMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody MCMem
kbody) (ImpM MCMem HostEnv Multicore ()
 -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ do
      let ([KernelResult]
red_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElemT LParamMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElemT LParamMem]
map_pes) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody MCMem -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody MCMem
kbody
          ([KernelResult]
buckets, [KernelResult]
vs) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([HistOp MCMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp MCMem]
histops) [KernelResult]
red_res
          perOp :: [a] -> [[a]]
perOp = [Int] -> [a] -> [[a]]
forall a. [Int] -> [a] -> [[a]]
chunks ([Int] -> [a] -> [[a]]) -> [Int] -> [a] -> [[a]]
forall a b. (a -> b) -> a -> b
$ (HistOp MCMem -> Int) -> [HistOp MCMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int)
-> (HistOp MCMem -> [VName]) -> HistOp MCMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp MCMem -> [VName]
forall lore. HistOp lore -> [VName]
histDest) [HistOp MCMem]
histops

      String
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"save map-out results" (ImpM MCMem HostEnv Multicore ()
 -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$
        [(PatElemT LParamMem, KernelResult)]
-> ((PatElemT LParamMem, KernelResult)
    -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LParamMem]
-> [KernelResult] -> [(PatElemT LParamMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LParamMem]
map_pes [KernelResult]
map_res) (((PatElemT LParamMem, KernelResult)
  -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> ((PatElemT LParamMem, KernelResult)
    -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LParamMem
pe, KernelResult
res) ->
          VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix
            (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe)
            ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
is)
            (KernelResult -> SubExp
kernelResultSubExp KernelResult
res)
            []

      [(HistOp MCMem, [VName], KernelResult, [KernelResult])]
-> ((HistOp MCMem, [VName], KernelResult, [KernelResult])
    -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp MCMem]
-> [[VName]]
-> [KernelResult]
-> [[KernelResult]]
-> [(HistOp MCMem, [VName], KernelResult, [KernelResult])]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 [HistOp MCMem]
histops [[VName]]
local_subhistograms [KernelResult]
buckets ([KernelResult] -> [[KernelResult]]
forall {a}. [a] -> [[a]]
perOp [KernelResult]
vs)) (((HistOp MCMem, [VName], KernelResult, [KernelResult])
  -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> ((HistOp MCMem, [VName], KernelResult, [KernelResult])
    -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$
        \( histop :: HistOp MCMem
histop@(HistOp SubExp
dest_w SubExp
_ [VName]
_ [SubExp]
_ Shape
shape Lambda MCMem
lam),
           [VName]
histop_subhistograms,
           KernelResult
bucket,
           [KernelResult]
vs'
           ) -> do
            let bucket' :: TExp Int64
bucket' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64) -> SubExp -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelResult -> SubExp
kernelResultSubExp KernelResult
bucket
                dest_w' :: TExp Int64
dest_w' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
dest_w
                bucket_in_bounds :: TPrimExp Bool ExpLeaf
bucket_in_bounds = TExp Int64
bucket' TExp Int64 -> TExp Int64 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
dest_w' TPrimExp Bool ExpLeaf
-> TPrimExp Bool ExpLeaf -> TPrimExp Bool ExpLeaf
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
0 TExp Int64 -> TExp Int64 -> TPrimExp Bool ExpLeaf
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
bucket'
                vs_params :: [Param LParamMem]
vs_params = Int -> [Param LParamMem] -> [Param LParamMem]
forall a. Int -> [a] -> [a]
takeLast ([KernelResult] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
vs') ([Param LParamMem] -> [Param LParamMem])
-> [Param LParamMem] -> [Param LParamMem]
forall a b. (a -> b) -> a -> b
$ Lambda MCMem -> [LParam MCMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda MCMem
lam
                bucket_is :: [TExp Int64]
bucket_is = [TExp Int64
bucket']

            String
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"perform updates" (ImpM MCMem HostEnv Multicore ()
 -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$
              TPrimExp Bool ExpLeaf
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
TPrimExp Bool ExpLeaf -> ImpM lore r op () -> ImpM lore r op ()
sWhen TPrimExp Bool ExpLeaf
bucket_in_bounds (ImpM MCMem HostEnv Multicore ()
 -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ do
                [LParam MCMem] -> ImpM MCMem HostEnv Multicore ()
forall lore r op. Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams ([LParam MCMem] -> ImpM MCMem HostEnv Multicore ())
-> [LParam MCMem] -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ Lambda MCMem -> [LParam MCMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda MCMem
lam
                Shape
-> ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
Shape -> ([TExp Int64] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest Shape
shape (([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is' -> do
                  [(Param LParamMem, KernelResult)]
-> ((Param LParamMem, KernelResult)
    -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [KernelResult] -> [(Param LParamMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
vs_params [KernelResult]
vs') (((Param LParamMem, KernelResult)
  -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> ((Param LParamMem, KernelResult)
    -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, KernelResult
res) ->
                    VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (KernelResult -> SubExp
kernelResultSubExp KernelResult
res) [TExp Int64]
is'
                  HistOp MCMem
-> [VName] -> [TExp Int64] -> ImpM MCMem HostEnv Multicore ()
updateHisto HistOp MCMem
histop [VName]
histop_subhistograms ([TExp Int64]
bucket_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is')

  -- Copy the task-local subhistograms to the global subhistograms,
  -- where they will be combined.
  Code
postbody <- ImpM MCMem HostEnv Multicore () -> MulticoreGen Code
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect (ImpM MCMem HostEnv Multicore () -> MulticoreGen Code)
-> ImpM MCMem HostEnv Multicore () -> MulticoreGen Code
forall a b. (a -> b) -> a -> b
$
    [(VName, VName)]
-> ((VName, VName) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip ([[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
global_subhistograms) ([[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
local_subhistograms)) (((VName, VName) -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> ((VName, VName) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$
      \(VName
global, VName
local) -> VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM MCMem HostEnv Multicore ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
global [TExp Int64
tid'] (VName -> SubExp
Var VName
local) []

  [Param]
free_params <- Code -> [VName] -> MulticoreGen [Param]
freeParams (Code
prebody Code -> Code -> Code
forall a. Semigroup a => a -> a -> a
<> Code
body Code -> Code -> Code
forall a. Semigroup a => a -> a -> a
<> Code
postbody) (SegSpace -> VName
segFlat SegSpace
space VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
flat_idx])
  let (Code
body_allocs, Code
body') = Code -> (Code, Code)
extractAllocations Code
body
  Code -> ImpM MCMem HostEnv Multicore ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code -> ImpM MCMem HostEnv Multicore ())
-> Code -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ Multicore -> Code
forall a. a -> Code a
Imp.Op (Multicore -> Code) -> Multicore -> Code
forall a b. (a -> b) -> a -> b
$ String
-> VName -> Code -> Code -> Code -> [Param] -> VName -> Multicore
Imp.ParLoop String
"seghist_stage_1" (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
flat_idx) (Code
body_allocs Code -> Code -> Code
forall a. Semigroup a => a -> a -> a
<> Code
prebody) Code
body' Code
postbody [Param]
free_params (VName -> Multicore) -> VName -> Multicore
forall a b. (a -> b) -> a -> b
$ SegSpace -> VName
segFlat SegSpace
space

  -- Perform a segmented reduction over the subhistograms
  [([PatElemT LParamMem], [VName], HistOp MCMem)]
-> (([PatElemT LParamMem], [VName], HistOp MCMem)
    -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([[PatElemT LParamMem]]
-> [[VName]]
-> [HistOp MCMem]
-> [([PatElemT LParamMem], [VName], HistOp MCMem)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElemT LParamMem]]
per_red_pes [[VName]]
global_subhistograms [HistOp MCMem]
histops) ((([PatElemT LParamMem], [VName], HistOp MCMem)
  -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> (([PatElemT LParamMem], [VName], HistOp MCMem)
    -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \([PatElemT LParamMem]
red_pes, [VName]
hists, HistOp MCMem
op) -> do
    VName
bucket_id <- String -> ImpM MCMem HostEnv Multicore VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"bucket_id"
    VName
subhistogram_id <- String -> ImpM MCMem HostEnv Multicore VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"subhistogram_id"

    let num_buckets :: SubExp
num_buckets = HistOp MCMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth HistOp MCMem
op
        segred_space :: SegSpace
segred_space =
          VName -> [(VName, SubExp)] -> SegSpace
SegSpace (SegSpace -> VName
segFlat SegSpace
space) ([(VName, SubExp)] -> SegSpace) -> [(VName, SubExp)] -> SegSpace
forall a b. (a -> b) -> a -> b
$
            [(VName, SubExp)]
segment_dims
              [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
bucket_id, SubExp
num_buckets)]
              [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
subhistogram_id, TV Int32 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int32
num_histos)]

        segred_op :: SegBinOp MCMem
segred_op = Commutativity
-> Lambda MCMem -> [SubExp] -> Shape -> SegBinOp MCMem
forall lore.
Commutativity -> Lambda lore -> [SubExp] -> Shape -> SegBinOp lore
SegBinOp Commutativity
Noncommutative (HistOp MCMem -> Lambda MCMem
forall lore. HistOp lore -> Lambda lore
histOp HistOp MCMem
op) (HistOp MCMem -> [SubExp]
forall lore. HistOp lore -> [SubExp]
histNeutral HistOp MCMem
op) (HistOp MCMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp MCMem
op)

    TV Int32
nsubtasks_red <- String -> PrimType -> ImpM MCMem HostEnv Multicore (TV Int32)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"num_tasks" (PrimType -> ImpM MCMem HostEnv Multicore (TV Int32))
-> PrimType -> ImpM MCMem HostEnv Multicore (TV Int32)
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
Int32
    Code
red_code <- Pattern MCMem
-> SegSpace
-> [SegBinOp MCMem]
-> TV Int32
-> DoSegBody
-> MulticoreGen Code
compileSegRed' ([PatElemT LParamMem] -> [PatElemT LParamMem] -> PatternT LParamMem
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT LParamMem]
red_pes) SegSpace
segred_space [SegBinOp MCMem
segred_op] TV Int32
nsubtasks_red (DoSegBody -> MulticoreGen Code) -> DoSegBody -> MulticoreGen Code
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [TExp Int64])] -> ImpM MCMem HostEnv Multicore ()
red_cont ->
      [(SubExp, [TExp Int64])] -> ImpM MCMem HostEnv Multicore ()
red_cont ([(SubExp, [TExp Int64])] -> ImpM MCMem HostEnv Multicore ())
-> [(SubExp, [TExp Int64])] -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$
        ((VName -> (SubExp, [TExp Int64]))
 -> [VName] -> [(SubExp, [TExp Int64])])
-> [VName]
-> (VName -> (SubExp, [TExp Int64]))
-> [(SubExp, [TExp Int64])]
forall a b c. (a -> b -> c) -> b -> a -> c
flip (VName -> (SubExp, [TExp Int64]))
-> [VName] -> [(SubExp, [TExp Int64])]
forall a b. (a -> b) -> [a] -> [b]
map [VName]
hists ((VName -> (SubExp, [TExp Int64])) -> [(SubExp, [TExp Int64])])
-> (VName -> (SubExp, [TExp Int64])) -> [(SubExp, [TExp Int64])]
forall a b. (a -> b) -> a -> b
$ \VName
subhisto ->
          ( VName -> SubExp
Var VName
subhisto,
            (VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 ([VName] -> [TExp Int64]) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$
              ((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst [(VName, SubExp)]
segment_dims [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName
subhistogram_id, VName
bucket_id]
          )

    let ns_red :: [TExp Int64]
ns_red = ((VName, SubExp) -> TExp Int64)
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> ((VName, SubExp) -> SubExp) -> (VName, SubExp) -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) ([(VName, SubExp)] -> [TExp Int64])
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
segred_space
        iterations :: TExp Int64
iterations = [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a]
init [TExp Int64]
ns_red -- The segmented reduction is sequential over the inner most dimension
        scheduler_info :: SchedulerInfo
scheduler_info = VName -> Exp -> Scheduling -> SchedulerInfo
Imp.SchedulerInfo (TV Int32 -> VName
forall t. TV t -> VName
tvVar TV Int32
nsubtasks_red) (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
iterations) Scheduling
Imp.Static
        red_task :: ParallelTask
red_task = Code -> VName -> ParallelTask
Imp.ParallelTask Code
red_code (VName -> ParallelTask) -> VName -> ParallelTask
forall a b. (a -> b) -> a -> b
$ SegSpace -> VName
segFlat SegSpace
space
    [Param]
free_params_red <- Code -> [VName] -> MulticoreGen [Param]
freeParams Code
red_code [SegSpace -> VName
segFlat SegSpace
space, TV Int32 -> VName
forall t. TV t -> VName
tvVar TV Int32
nsubtasks_red]
    Code -> ImpM MCMem HostEnv Multicore ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code -> ImpM MCMem HostEnv Multicore ())
-> Code -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ Multicore -> Code
forall a. a -> Code a
Imp.Op (Multicore -> Code) -> Multicore -> Code
forall a b. (a -> b) -> a -> b
$ String
-> [Param]
-> ParallelTask
-> Maybe ParallelTask
-> [Param]
-> SchedulerInfo
-> Multicore
Imp.Segop String
"seghist_red" [Param]
free_params_red ParallelTask
red_task Maybe ParallelTask
forall a. Maybe a
Nothing [Param]
forall a. Monoid a => a
mempty SchedulerInfo
scheduler_info
  where
    segment_dims :: [(VName, SubExp)]
segment_dims = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
init ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space

-- | Segmented histogram that parallelises only over the segments;
-- within each segment the updates are applied sequentially by the
-- code from 'compileSegHistBody'.
--
-- The debug print and the declaration of the segment iterator are
-- emitted into the enclosing code.  The returned 'Imp.Code' contains
-- the @"segmented_hist"@ parallel loop over the segments.
segmentedHist ::
  Pattern MCMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen Imp.Code
segmentedHist pat space histops kbody = do
  emit $ Imp.DebugPrint "Segmented segHist" Nothing
  -- Loop variable ranging over the segments.
  seg_iter <- dPrim "segment_iter" $ IntType Int64
  collect $ do
    seg_body <- compileSegHistBody (tvExp seg_iter) pat space histops kbody
    let flat = segFlat space
        (allocs, seg_body') = extractAllocations seg_body
    params <- freeParams seg_body [flat, tvVar seg_iter]
    emit . Imp.Op $
      Imp.ParLoop "segmented_hist" (tvVar seg_iter) allocs seg_body' mempty params flat

-- | Generate the sequential code that processes a single segment of a
-- segmented histogram.
--
-- @idx@ is the flat index of the segment being processed. It is
-- unflattened over all but the innermost dimension of @space@, and the
-- innermost dimension is iterated over sequentially. For each element:
--
-- * the kernel body is evaluated;
-- * its map-out results, if any, are written to the corresponding
--   pattern elements;
-- * for every histogram operator whose bucket index is in bounds, the
--   current destination value and the new value are fed to the
--   operator's lambda, and the result is written back.
compileSegHistBody ::
  Imp.TExp Int64 ->
  Pattern MCMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen Imp.Code
compileSegHistBody idx pat space histops kbody = do
  let (is, ns) = unzip $ unSegSpace space
      ns_64 = map toInt64Exp ns
      -- Index of the current segment (all but the innermost dimension).
      segment_is = map Imp.vi64 $ init is

  let all_pes = patternValueElements pat
      -- The pattern starts with the histogram destinations: one element
      -- per neutral value of each operator, as also assumed by
      -- 'segHistOpChunks' below. The bucket indices occur only among
      -- the kernel results, not in the pattern, so they are not counted
      -- here. Whatever follows the destinations is map-out results.
      num_hist_pes = sum $ map (length . histNeutral) histops
      map_pes = drop num_hist_pes all_pes
      per_red_pes = segHistOpChunks histops all_pes

  collect $ do
    let inner_bound = last ns_64
    sFor "i" inner_bound $ \i -> do
      zipWithM_ dPrimV_ (init is) $ unflattenIndex (init ns_64) idx
      dPrimV_ (last is) i

      compileStms mempty (kernelBodyStms kbody) $ do
        -- Kernel results: one bucket per operator, then the values of
        -- every operator, then the map-out results.
        let (red_res, map_res) =
              splitFromEnd (length map_pes) $
                map kernelResultSubExp $ kernelBodyResult kbody
            (buckets, vs) = splitAt (length histops) red_res
            perOp = chunks $ map (length . histDest) histops

        -- Written once per element, independently of the number of
        -- histogram operators (previously repeated for every operator,
        -- and skipped entirely when there were none).
        sComment "save map-out results" $
          forM_ (zip map_pes map_res) $ \(pe, res) ->
            copyDWIMFix (patElemName pe) (map Imp.vi64 is) res []

        forM_ (zip4 per_red_pes histops (perOp vs) buckets) $
          \(red_pes, HistOp dest_w _ _ _ shape lam, vs', bucket) -> do
            -- The first half of the lambda parameters receives the
            -- current histogram values, the second half the new values.
            let (hist_params, vs_params) = splitAt (length vs') $ lambdaParams lam
                bucket' = toInt64Exp bucket
                dest_w' = toInt64Exp dest_w
                bucket_in_bounds = bucket' .<. dest_w' .&&. 0 .<=. bucket'
                -- Full index into a destination array for the current
                -- segment, bucket, and vector position.
                dest_is vec_is = segment_is ++ [bucket'] ++ vec_is

            sComment "perform updates" $
              sWhen bucket_in_bounds $ do
                dLParams $ lambdaParams lam
                sLoopNest shape $ \vec_is -> do
                  -- Current value stored in the histogram.
                  forM_ (zip red_pes hist_params) $ \(pe, p) ->
                    copyDWIMFix (paramName p) [] (Var $ patElemName pe) (dest_is vec_is)
                  -- New value to combine with it.
                  forM_ (zip vs_params vs') $ \(p, v) ->
                    copyDWIMFix (paramName p) [] v vec_is
                  compileStms mempty (bodyStms $ lambdaBody lam) $
                    forM_ (zip red_pes $ bodyResult $ lambdaBody lam) $
                      \(pe, se) -> copyDWIMFix (patElemName pe) (dest_is vec_is) se []