module Futhark.CodeGen.ImpGen.Multicore.SegHist
  ( compileSegHist,
  )
where

import Control.Monad
import Data.List (zip4)
import qualified Futhark.CodeGen.ImpCode.Multicore as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.CodeGen.ImpGen.Multicore.SegRed (compileSegRed')
import Futhark.IR.MCMem
import Futhark.MonadFreshNames
import Futhark.Util (chunks, splitFromEnd, takeLast)
import Futhark.Util.IntegralExp (rem)
import Prelude hiding (quot, rem)

-- | Generate multicore code for a histogram operation.
--
-- The strategy is chosen from the segment space: when it has exactly
-- one dimension the histogram is compiled by 'nonsegmentedHist' (which
-- also receives the number of subtasks); any other space is handed to
-- 'segmentedHist'.
compileSegHist ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.MCCode
compileSegHist pat space histops kbody nsubtasks
  | [_] <- unSegSpace space =
    nonsegmentedHist pat space histops kbody nsubtasks
  | otherwise =
    segmentedHist pat space histops kbody

-- | Split some list into chunks whose sizes equal the number of
-- neutral elements ('histNeutral') of each 'HistOp', in order.
segHistOpChunks :: [HistOp rep] -> [a] -> [[a]]
segHistOpChunks = chunks . map (length . histNeutral)

-- | The product of all dimensions of the 'histShape' of a histogram
-- operator, as a 64-bit expression.
histSize :: HistOp MCMem -> Imp.TExp Int64
histSize = product . map toInt64Exp . shapeDims . histShape

-- | Compile a histogram over a one-dimensional segment space.
--
-- Two strategies are available and the choice is made at run time in
-- the generated code: if @num_histos * hist_width@ does not exceed the
-- total number of input elements, 'subHistogram' is used; otherwise
-- 'atomicHistogram' is used.  No code is executed when the input is
-- empty.  The generated code is collected and returned rather than
-- emitted.
nonsegmentedHist ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.MCCode
nonsegmentedHist pat space histops kbody num_histos = do
  let ns = map snd $ unSegSpace space
      ns_64 = map toInt64Exp ns
      num_histos' = tvExp num_histos
      -- Only the first operator's size is considered when picking a
      -- strategy.
      -- NOTE(review): 'head' is partial; this assumes that there is at
      -- least one histogram operator -- confirm against callers.
      hist_width = histSize $ head histops
      use_subhistogram = sExt64 num_histos' * hist_width .<=. product ns_64

  -- The atomic branch is given the operators returned by
  -- 'renameHistOpLambda'; the subhistogram branch uses the originals.
  histops' <- renameHistOpLambda histops

  -- Only do something if there is actually input.
  collect $
    sUnless (product ns_64 .==. 0) $ do
      sIf
        use_subhistogram
        (subHistogram pat space histops num_histos kbody)
        (atomicHistogram pat space histops' kbody)

-- |
-- Atomic Histogram approach
-- The implementation has three sub-strategies depending on the
-- type of the operator
-- 1. If values are integral scalars, a directly supported atomic update is used.
-- 2. If values are on one memory location, e.g. a float, then a
-- CAS operation is used to perform the update, where the float is
-- cast to an integral scalar.
-- 1. and 2. currently only work for 32-bit and 64-bit types,
-- but GCC has support for 8-, 16- and 128-bit types as well.
-- 3. Otherwise a locking based approach is used
--
-- The result is a function that, given the destination arrays and an
-- index, generates the code performing the update.
onOpAtomic :: HistOp MCMem -> MulticoreGen ([VName] -> [Imp.TExp Int64] -> MulticoreGen ())
onOpAtomic op = do
  atomics <- hostAtomics <$> askEnv
  let lambda = histOp op
      do_op = atomicUpdateLocking atomics lambda
  case do_op of
    -- Cases 1 and 2: the update function can be used as is.
    AtomicPrim f -> return f
    AtomicCAS f -> return f
    -- Case 3: the update function additionally needs a lock description.
    AtomicLocking f -> do
      -- Allocate a static array of locks
      -- as in the GPU backend
      let num_locks = 100151 -- This number is taken from the GPU backend
          dims = map toInt64Exp $ shapeDims (histOpShape op <> histShape op)
      locks <-
        sStaticArray "hist_locks" DefaultSpace int32 $
          Imp.ArrayZeros num_locks
      -- An index is mapped to a lock by flattening it with respect to
      -- 'dims' and reducing the result modulo the number of locks.
      let l' = Locking locks 0 1 0 (pure . (`rem` fromIntegral num_locks) . flattenIndex dims)
      return $ f l'

-- | Compute the histogram by updating the destination arrays directly
-- with the update functions produced by 'onOpAtomic'.
--
-- The generated code is a single parallel loop (named
-- @atomic_seg_hist@) over the flattened segment space.  Each iteration
-- runs the kernel body, copies the map-out results to their pattern
-- elements, and then, for every operator whose bucket is in bounds,
-- applies the update.
atomicHistogram ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen ()
atomicHistogram pat space histops kbody = do
  let (is, ns) = unzip $ unSegSpace space
      ns_64 = map toInt64Exp ns
  let num_red_res = length histops + sum (map (length . histNeutral) histops)
      (all_red_pes, map_pes) = splitAt num_red_res $ patElems pat

  -- One update function per operator, in the same order as 'histops'.
  atomicOps <- mapM onOpAtomic histops

  body <- collect $ do
    dPrim_ (segFlat space) int64
    sOp $ Imp.GetTaskId (segFlat space)
    generateChunkLoop "SegHist" $ \flat_idx -> do
      -- Bind the segment space indices for this iteration.
      zipWithM_ dPrimV_ is $ unflattenIndex ns_64 flat_idx
      compileStms mempty (kernelBodyStms kbody) $ do
        -- The map-out results are the trailing kernel results; whatever
        -- precedes them is split per operator into bucket and values.
        let (red_res, map_res) =
              splitFromEnd (length map_pes) $ kernelBodyResult kbody
            red_res_split = splitHistResults histops $ map kernelResultSubExp red_res

        let pes_per_op = chunks (map (length . histDest) histops) all_red_pes
        forM_ (zip4 histops red_res_split atomicOps pes_per_op) $
          \(HistOp dest_shape _ _ _ shape lam, (bucket, vs'), do_op, dest_res) -> do
            let (_is_params, vs_params) = splitAt (length vs') $ lambdaParams lam
                dest_shape' = map toInt64Exp $ shapeDims dest_shape
                bucket' = map toInt64Exp bucket
                bucket_in_bounds = inBounds (Slice (map DimFix bucket')) dest_shape'

            sComment "save map-out results" $
              forM_ (zip map_pes map_res) $ \(pe, res) ->
                copyDWIMFix (patElemName pe) (map Imp.le64 is) (kernelResultSubExp res) []

            sComment "perform updates" $
              sWhen bucket_in_bounds $ do
                -- All but the last space index, followed by the bucket.
                -- On the call path visible in this module (through
                -- 'nonsegmentedHist') the space has a single dimension,
                -- so the prefix is empty.
                let bucket_is = map Imp.le64 (init is) ++ bucket'
                dLParams $ lambdaParams lam
                sLoopNest shape $ \is' -> do
                  -- Load the values to combine into the value
                  -- parameters of the operator, then update.
                  forM_ (zip vs_params vs') $ \(p, res) ->
                    copyDWIMFix (paramName p) [] res is'
                  do_op (map patElemName dest_res) (bucket_is ++ is')

  free_params <- freeParams body
  emit $ Imp.Op $ Imp.ParLoop "atomic_seg_hist" body free_params

-- | Generate code that applies the operator of a 'HistOp' to one
-- bucket of the given arrays.
--
-- The first @length arrs@ parameters of the operator are declared and
-- loaded with the current contents of @arrs@ at @bucket@, the operator
-- body is compiled, and its results are written back to @arrs@ at
-- @bucket@.  The remaining operator parameters are not declared here.
-- NOTE(review): presumably the caller has already bound them -- confirm.
updateHisto :: HistOp MCMem -> [VName] -> [Imp.TExp Int64] -> MulticoreGen ()
updateHisto op arrs bucket = do
  let acc_params = take (length arrs) $ lambdaParams $ histOp op
      -- Read the current bucket contents into the accumulator parameters.
      bind_acc_params =
        forM_ (zip acc_params arrs) $ \(acc_p, arr) ->
          copyDWIMFix (paramName acc_p) [] (Var arr) bucket
      op_body = compileBody' [] $ lambdaBody $ histOp op
      writeArray arr val = copyDWIMFix arr bucket val []
      -- Store each result of the operator body in the matching array.
      do_hist = zipWithM_ writeArray arrs $ map resSubExp $ bodyResult $ lambdaBody $ histOp op

  sComment "Start of body" $ do
    dLParams acc_params
    bind_acc_params
    op_body
    do_hist

-- Generates num_histos sub-histograms of the size
-- of the destination histogram
-- Then for each chunk of the input each subhistogram
-- is computed and finally combined through a segmented reduction
-- across the histogram indices.
-- This is expected to be fast if len(histDest) is small
subHistogram ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  TV Int32 ->
  KernelBody MCMem ->
  MulticoreGen ()
subHistogram :: Pat LetDecMem
-> SegSpace
-> [HistOp MCMem]
-> TV Int32
-> KernelBody MCMem
-> ImpM MCMem HostEnv Multicore ()
subHistogram Pat LetDecMem
pat SegSpace
space [HistOp MCMem]
histops TV Int32
num_histos KernelBody MCMem
kbody = do
  MCCode -> ImpM MCMem HostEnv Multicore ()
forall op rep r. Code op -> ImpM rep r op ()
emit (MCCode -> ImpM MCMem HostEnv Multicore ())
-> MCCode -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> MCCode
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"subHistogram segHist" Maybe Exp
forall a. Maybe a
Nothing

  let ([VName]
is, [SubExp]
ns) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      ns_64 :: [TExp Int64]
ns_64 = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
ns

  let pes :: [PatElem LetDecMem]
pes = Pat LetDecMem -> [PatElem LetDecMem]
forall dec. Pat dec -> [PatElem dec]
patElems Pat LetDecMem
pat
      num_red_res :: Int
num_red_res = [HistOp MCMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp MCMem]
histops Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((HistOp MCMem -> Int) -> [HistOp MCMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (HistOp MCMem -> [SubExp]) -> HistOp MCMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp MCMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral) [HistOp MCMem]
histops)
      map_pes :: [PatElem LetDecMem]
map_pes = Int -> [PatElem LetDecMem] -> [PatElem LetDecMem]
forall a. Int -> [a] -> [a]
drop Int
num_red_res [PatElem LetDecMem]
pes
      per_red_pes :: [[PatElem LetDecMem]]
per_red_pes = [HistOp MCMem] -> [PatElem LetDecMem] -> [[PatElem LetDecMem]]
forall rep a. [HistOp rep] -> [a] -> [[a]]
segHistOpChunks [HistOp MCMem]
histops ([PatElem LetDecMem] -> [[PatElem LetDecMem]])
-> [PatElem LetDecMem] -> [[PatElem LetDecMem]]
forall a b. (a -> b) -> a -> b
$ Pat LetDecMem -> [PatElem LetDecMem]
forall dec. Pat dec -> [PatElem dec]
patElems Pat LetDecMem
pat

  -- Allocate array of subhistograms in the calling thread.  Each
  -- tasks will work in its own private allocations (to avoid false
  -- sharing), but this is where they will ultimately copy their
  -- results.
  [[VName]]
global_subhistograms <- [HistOp MCMem]
-> (HistOp MCMem -> ImpM MCMem HostEnv Multicore [VName])
-> ImpM MCMem HostEnv Multicore [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp MCMem]
histops ((HistOp MCMem -> ImpM MCMem HostEnv Multicore [VName])
 -> ImpM MCMem HostEnv Multicore [[VName]])
-> (HistOp MCMem -> ImpM MCMem HostEnv Multicore [VName])
-> ImpM MCMem HostEnv Multicore [[VName]]
forall a b. (a -> b) -> a -> b
$ \HistOp MCMem
histop ->
    [Type]
-> (Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (HistOp MCMem -> [Type]
forall rep. HistOp rep -> [Type]
histType HistOp MCMem
histop) ((Type -> ImpM MCMem HostEnv Multicore VName)
 -> ImpM MCMem HostEnv Multicore [VName])
-> (Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName]
forall a b. (a -> b) -> a -> b
$ \Type
t -> do
      let shape :: ShapeBase SubExp
shape = [SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [TV Int32 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int32
num_histos] ShapeBase SubExp -> ShapeBase SubExp -> ShapeBase SubExp
forall a. Semigroup a => a -> a -> a
<> Type -> ShapeBase SubExp
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t
      String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM MCMem HostEnv Multicore VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"subhistogram" (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) ShapeBase SubExp
shape Space
DefaultSpace

  let tid' :: TExp Int64
tid' = VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 (VName -> TExp Int64) -> VName -> TExp Int64
forall a b. (a -> b) -> a -> b
$ SegSpace -> VName
segFlat SegSpace
space

  -- Generate loop body of parallel function
  MCCode
body <- ImpM MCMem HostEnv Multicore () -> MulticoreGen MCCode
forall rep r op. ImpM rep r op () -> ImpM rep r op (Code op)
collect (ImpM MCMem HostEnv Multicore () -> MulticoreGen MCCode)
-> ImpM MCMem HostEnv Multicore () -> MulticoreGen MCCode
forall a b. (a -> b) -> a -> b
$ do
    VName -> PrimType -> ImpM MCMem HostEnv Multicore ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ (SegSpace -> VName
segFlat SegSpace
space) PrimType
int64
    Multicore -> ImpM MCMem HostEnv Multicore ()
forall op rep r. op -> ImpM rep r op ()
sOp (Multicore -> ImpM MCMem HostEnv Multicore ())
-> Multicore -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ VName -> Multicore
Imp.GetTaskId (SegSpace -> VName
segFlat SegSpace
space)

    [[VName]]
local_subhistograms <- [([PatElem LetDecMem], HistOp MCMem)]
-> (([PatElem LetDecMem], HistOp MCMem)
    -> ImpM MCMem HostEnv Multicore [VName])
-> ImpM MCMem HostEnv Multicore [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([[PatElem LetDecMem]]
-> [HistOp MCMem] -> [([PatElem LetDecMem], HistOp MCMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [[PatElem LetDecMem]]
per_red_pes [HistOp MCMem]
histops) ((([PatElem LetDecMem], HistOp MCMem)
  -> ImpM MCMem HostEnv Multicore [VName])
 -> ImpM MCMem HostEnv Multicore [[VName]])
-> (([PatElem LetDecMem], HistOp MCMem)
    -> ImpM MCMem HostEnv Multicore [VName])
-> ImpM MCMem HostEnv Multicore [[VName]]
forall a b. (a -> b) -> a -> b
$ \([PatElem LetDecMem]
pes', HistOp MCMem
histop) -> do
      [VName]
op_local_subhistograms <- [Type]
-> (Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (HistOp MCMem -> [Type]
forall rep. HistOp rep -> [Type]
histType HistOp MCMem
histop) ((Type -> ImpM MCMem HostEnv Multicore VName)
 -> ImpM MCMem HostEnv Multicore [VName])
-> (Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName]
forall a b. (a -> b) -> a -> b
$ \Type
t ->
        String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM MCMem HostEnv Multicore VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"subhistogram" (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (Type -> ShapeBase SubExp
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t) Space
DefaultSpace

      [(PatElem LetDecMem, VName, SubExp)]
-> ((PatElem LetDecMem, VName, SubExp)
    -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LetDecMem]
-> [VName] -> [SubExp] -> [(PatElem LetDecMem, VName, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem LetDecMem]
pes' [VName]
op_local_subhistograms (HistOp MCMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral HistOp MCMem
histop)) (((PatElem LetDecMem, VName, SubExp)
  -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> ((PatElem LetDecMem, VName, SubExp)
    -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LetDecMem
pe, VName
hist, SubExp
ne) ->
        -- First thread initializes histogram with dest vals. Others
        -- initialize with neutral element
        TPrimExp Bool VName
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
          (TExp Int64
tid' TExp Int64 -> TExp Int64 -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0)
          (VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM MCMem HostEnv Multicore ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
hist [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElem LetDecMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LetDecMem
pe) [])
          ( ShapeBase SubExp
-> ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall rep r op.
ShapeBase SubExp
-> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (HistOp MCMem -> ShapeBase SubExp
forall rep. HistOp rep -> ShapeBase SubExp
histShape HistOp MCMem
histop) (([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
shape_is ->
              ShapeBase SubExp
-> ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall rep r op.
ShapeBase SubExp
-> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (HistOp MCMem -> ShapeBase SubExp
forall rep. HistOp rep -> ShapeBase SubExp
histOpShape HistOp MCMem
histop) (([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is ->
                VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM MCMem HostEnv Multicore ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
hist ([TExp Int64]
shape_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. Semigroup a => a -> a -> a
<> [TExp Int64]
vec_is) SubExp
ne []
          )

      [VName] -> ImpM MCMem HostEnv Multicore [VName]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName]
op_local_subhistograms

    String
-> (TExp Int64 -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
generateChunkLoop String
"SegRed" ((TExp Int64 -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> (TExp Int64 -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
      (VName -> TExp Int64 -> ImpM MCMem HostEnv Multicore ())
-> [VName] -> [TExp Int64] -> ImpM MCMem HostEnv Multicore ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> ImpM MCMem HostEnv Multicore ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
is ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> [TExp Int64] -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
ns_64 TExp Int64
i
      Names
-> Stms MCMem
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody MCMem -> Stms MCMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody MCMem
kbody) (ImpM MCMem HostEnv Multicore ()
 -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ do
        let ([SubExp]
red_res, [SubExp]
map_res) =
              Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElem LetDecMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem LetDecMem]
map_pes) ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$
                (KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp ([KernelResult] -> [SubExp]) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ KernelBody MCMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody MCMem
kbody

        String
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"save map-out results" (ImpM MCMem HostEnv Multicore ()
 -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$
          [(PatElem LetDecMem, SubExp)]
-> ((PatElem LetDecMem, SubExp) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LetDecMem] -> [SubExp] -> [(PatElem LetDecMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LetDecMem]
map_pes [SubExp]
map_res) (((PatElem LetDecMem, SubExp) -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> ((PatElem LetDecMem, SubExp) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LetDecMem
pe, SubExp
res) ->
            VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM MCMem HostEnv Multicore ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (PatElem LetDecMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LetDecMem
pe) ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
is) SubExp
res []

        [(HistOp MCMem, [VName], ([SubExp], [SubExp]))]
-> ((HistOp MCMem, [VName], ([SubExp], [SubExp]))
    -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp MCMem]
-> [[VName]]
-> [([SubExp], [SubExp])]
-> [(HistOp MCMem, [VName], ([SubExp], [SubExp]))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [HistOp MCMem]
histops [[VName]]
local_subhistograms ([HistOp MCMem] -> [SubExp] -> [([SubExp], [SubExp])]
forall rep. [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults [HistOp MCMem]
histops [SubExp]
red_res)) (((HistOp MCMem, [VName], ([SubExp], [SubExp]))
  -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> ((HistOp MCMem, [VName], ([SubExp], [SubExp]))
    -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$
          \( histop :: HistOp MCMem
histop@(HistOp ShapeBase SubExp
dest_shape SubExp
_ [VName]
_ [SubExp]
_ ShapeBase SubExp
shape Lambda MCMem
lam),
             [VName]
histop_subhistograms,
             ([SubExp]
bucket, [SubExp]
vs')
             ) -> do
              let bucket' :: [TExp Int64]
bucket' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
bucket
                  dest_shape' :: [TExp Int64]
dest_shape' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
dest_shape
                  bucket_in_bounds :: TPrimExp Bool VName
bucket_in_bounds =
                    Slice (TExp Int64) -> [TExp Int64] -> TPrimExp Bool VName
inBounds ([DimIndex (TExp Int64)] -> Slice (TExp Int64)
forall d. [DimIndex d] -> Slice d
Slice ((TExp Int64 -> DimIndex (TExp Int64))
-> [TExp Int64] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix [TExp Int64]
bucket')) [TExp Int64]
dest_shape'
                  vs_params :: [Param LetDecMem]
vs_params = Int -> [Param LetDecMem] -> [Param LetDecMem]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') ([Param LetDecMem] -> [Param LetDecMem])
-> [Param LetDecMem] -> [Param LetDecMem]
forall a b. (a -> b) -> a -> b
$ Lambda MCMem -> [LParam MCMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda MCMem
lam

              String
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"perform updates" (ImpM MCMem HostEnv Multicore ()
 -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$
                TPrimExp Bool VName
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen TPrimExp Bool VName
bucket_in_bounds (ImpM MCMem HostEnv Multicore ()
 -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ do
                  [LParam MCMem] -> ImpM MCMem HostEnv Multicore ()
forall rep inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam MCMem] -> ImpM MCMem HostEnv Multicore ())
-> [LParam MCMem] -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ Lambda MCMem -> [LParam MCMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda MCMem
lam
                  ShapeBase SubExp
-> ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall rep r op.
ShapeBase SubExp
-> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest ShapeBase SubExp
shape (([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> ([TExp Int64] -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is' -> do
                    [(Param LetDecMem, SubExp)]
-> ((Param LetDecMem, SubExp) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [SubExp] -> [(Param LetDecMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
vs_params [SubExp]
vs') (((Param LetDecMem, SubExp) -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> ((Param LetDecMem, SubExp) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, SubExp
res) ->
                      VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM MCMem HostEnv Multicore ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] SubExp
res [TExp Int64]
is'
                    HistOp MCMem
-> [VName] -> [TExp Int64] -> ImpM MCMem HostEnv Multicore ()
updateHisto HistOp MCMem
histop [VName]
histop_subhistograms ([TExp Int64]
bucket' [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is')

    -- Copy the task-local subhistograms to the global subhistograms,
    -- where they will be combined.
    [(VName, VName)]
-> ((VName, VName) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip ([[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
global_subhistograms) ([[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
local_subhistograms)) (((VName, VName) -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> ((VName, VName) -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$
      \(VName
global, VName
local) -> VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM MCMem HostEnv Multicore ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
global [TExp Int64
tid'] (VName -> SubExp
Var VName
local) []

  [Param]
free_params <- MCCode -> MulticoreGen [Param]
forall a. FreeIn a => a -> MulticoreGen [Param]
freeParams MCCode
body
  MCCode -> ImpM MCMem HostEnv Multicore ()
forall op rep r. Code op -> ImpM rep r op ()
emit (MCCode -> ImpM MCMem HostEnv Multicore ())
-> MCCode -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ Multicore -> MCCode
forall a. a -> Code a
Imp.Op (Multicore -> MCCode) -> Multicore -> MCCode
forall a b. (a -> b) -> a -> b
$ String -> MCCode -> [Param] -> Multicore
Imp.ParLoop String
"seghist_stage_1" MCCode
body [Param]
free_params

  -- Perform a segmented reduction over the subhistograms
  [([PatElem LetDecMem], [VName], HistOp MCMem)]
-> (([PatElem LetDecMem], [VName], HistOp MCMem)
    -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([[PatElem LetDecMem]]
-> [[VName]]
-> [HistOp MCMem]
-> [([PatElem LetDecMem], [VName], HistOp MCMem)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElem LetDecMem]]
per_red_pes [[VName]]
global_subhistograms [HistOp MCMem]
histops) ((([PatElem LetDecMem], [VName], HistOp MCMem)
  -> ImpM MCMem HostEnv Multicore ())
 -> ImpM MCMem HostEnv Multicore ())
-> (([PatElem LetDecMem], [VName], HistOp MCMem)
    -> ImpM MCMem HostEnv Multicore ())
-> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \([PatElem LetDecMem]
red_pes, [VName]
hists, HistOp MCMem
op) -> do
    [VName]
bucket_ids <-
      Int
-> ImpM MCMem HostEnv Multicore VName
-> ImpM MCMem HostEnv Multicore [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (ShapeBase SubExp -> Int
forall a. ArrayShape a => a -> Int
shapeRank (HistOp MCMem -> ShapeBase SubExp
forall rep. HistOp rep -> ShapeBase SubExp
histShape HistOp MCMem
op)) (String -> ImpM MCMem HostEnv Multicore VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"bucket_id")
    VName
subhistogram_id <- String -> ImpM MCMem HostEnv Multicore VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"subhistogram_id"

    let segred_space :: SegSpace
segred_space =
          VName -> [(VName, SubExp)] -> SegSpace
SegSpace (SegSpace -> VName
segFlat SegSpace
space) ([(VName, SubExp)] -> SegSpace) -> [(VName, SubExp)] -> SegSpace
forall a b. (a -> b) -> a -> b
$
            [(VName, SubExp)]
segment_dims
              [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
bucket_ids (ShapeBase SubExp -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp MCMem -> ShapeBase SubExp
forall rep. HistOp rep -> ShapeBase SubExp
histShape HistOp MCMem
op))
              [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
subhistogram_id, TV Int32 -> SubExp
forall t. TV t -> SubExp
tvSize TV Int32
num_histos)]

        segred_op :: SegBinOp MCMem
segred_op = Commutativity
-> Lambda MCMem -> [SubExp] -> ShapeBase SubExp -> SegBinOp MCMem
forall rep.
Commutativity
-> Lambda rep -> [SubExp] -> ShapeBase SubExp -> SegBinOp rep
SegBinOp Commutativity
Noncommutative (HistOp MCMem -> Lambda MCMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp MCMem
op) (HistOp MCMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral HistOp MCMem
op) (HistOp MCMem -> ShapeBase SubExp
forall rep. HistOp rep -> ShapeBase SubExp
histOpShape HistOp MCMem
op)

    MCCode
red_code <- ImpM MCMem HostEnv Multicore () -> MulticoreGen MCCode
forall rep r op. ImpM rep r op () -> ImpM rep r op (Code op)
collect (ImpM MCMem HostEnv Multicore () -> MulticoreGen MCCode)
-> ImpM MCMem HostEnv Multicore () -> MulticoreGen MCCode
forall a b. (a -> b) -> a -> b
$ do
      TV Int32
nsubtasks <- String -> PrimType -> ImpM MCMem HostEnv Multicore (TV Int32)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"nsubtasks" PrimType
int32
      Multicore -> ImpM MCMem HostEnv Multicore ()
forall op rep r. op -> ImpM rep r op ()
sOp (Multicore -> ImpM MCMem HostEnv Multicore ())
-> Multicore -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ VName -> Multicore
Imp.GetNumTasks (VName -> Multicore) -> VName -> Multicore
forall a b. (a -> b) -> a -> b
$ TV Int32 -> VName
forall t. TV t -> VName
tvVar TV Int32
nsubtasks
      MCCode -> ImpM MCMem HostEnv Multicore ()
forall op rep r. Code op -> ImpM rep r op ()
emit (MCCode -> ImpM MCMem HostEnv Multicore ())
-> (DoSegBody -> MulticoreGen MCCode)
-> DoSegBody
-> ImpM MCMem HostEnv Multicore ()
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< Pat LetDecMem
-> SegSpace
-> [SegBinOp MCMem]
-> TV Int32
-> DoSegBody
-> MulticoreGen MCCode
compileSegRed' ([PatElem LetDecMem] -> Pat LetDecMem
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem LetDecMem]
red_pes) SegSpace
segred_space [SegBinOp MCMem
segred_op] TV Int32
nsubtasks (DoSegBody -> ImpM MCMem HostEnv Multicore ())
-> DoSegBody -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [TExp Int64])] -> ImpM MCMem HostEnv Multicore ()
red_cont ->
        [(SubExp, [TExp Int64])] -> ImpM MCMem HostEnv Multicore ()
red_cont ([(SubExp, [TExp Int64])] -> ImpM MCMem HostEnv Multicore ())
-> [(SubExp, [TExp Int64])] -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$
          ((VName -> (SubExp, [TExp Int64]))
 -> [VName] -> [(SubExp, [TExp Int64])])
-> [VName]
-> (VName -> (SubExp, [TExp Int64]))
-> [(SubExp, [TExp Int64])]
forall a b c. (a -> b -> c) -> b -> a -> c
flip (VName -> (SubExp, [TExp Int64]))
-> [VName] -> [(SubExp, [TExp Int64])]
forall a b. (a -> b) -> [a] -> [b]
map [VName]
hists ((VName -> (SubExp, [TExp Int64])) -> [(SubExp, [TExp Int64])])
-> (VName -> (SubExp, [TExp Int64])) -> [(SubExp, [TExp Int64])]
forall a b. (a -> b) -> a -> b
$ \VName
subhisto ->
            ( VName -> SubExp
Var VName
subhisto,
              (VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 ([VName] -> [TExp Int64]) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$
                ((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst [(VName, SubExp)]
segment_dims [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName
subhistogram_id] [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
bucket_ids
            )

    let ns_red :: [TExp Int64]
ns_red = ((VName, SubExp) -> TExp Int64)
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> ((VName, SubExp) -> SubExp) -> (VName, SubExp) -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) ([(VName, SubExp)] -> [TExp Int64])
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
segred_space
        iterations :: TExp Int64
iterations = [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a]
init [TExp Int64]
ns_red -- The segmented reduction is sequential over the inner most dimension
        scheduler_info :: SchedulerInfo
scheduler_info = Exp -> Scheduling -> SchedulerInfo
Imp.SchedulerInfo (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
iterations) Scheduling
Imp.Static
        red_task :: ParallelTask
red_task = MCCode -> ParallelTask
Imp.ParallelTask MCCode
red_code
    [Param]
free_params_red <- MCCode -> MulticoreGen [Param]
forall a. FreeIn a => a -> MulticoreGen [Param]
freeParams MCCode
red_code
    MCCode -> ImpM MCMem HostEnv Multicore ()
forall op rep r. Code op -> ImpM rep r op ()
emit (MCCode -> ImpM MCMem HostEnv Multicore ())
-> MCCode -> ImpM MCMem HostEnv Multicore ()
forall a b. (a -> b) -> a -> b
$ Multicore -> MCCode
forall a. a -> Code a
Imp.Op (Multicore -> MCCode) -> Multicore -> MCCode
forall a b. (a -> b) -> a -> b
$ String
-> [Param]
-> ParallelTask
-> Maybe ParallelTask
-> [Param]
-> SchedulerInfo
-> Multicore
Imp.SegOp String
"seghist_red" [Param]
free_params_red ParallelTask
red_task Maybe ParallelTask
forall a. Maybe a
Nothing [Param]
forall a. Monoid a => a
mempty SchedulerInfo
scheduler_info
  where
    segment_dims :: [(VName, SubExp)]
segment_dims = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
init ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space

-- | Segmented histogram.
--
-- This implementation of a segmented histogram only parallelizes
-- over the segments; each segment is updated sequentially.
--
-- A debug print is emitted directly, while the returned code consists
-- of a single @ParLoop@ named @"segmented_hist"@ whose body is
-- produced by 'compileSegHistBody'.
segmentedHist ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen Imp.MCCode
segmentedHist pat space histops kbody = do
  emit $ Imp.DebugPrint "Segmented segHist" Nothing
  collect $ do
    body <- compileSegHistBody pat space histops kbody
    -- The loop body is passed to ParLoop together with the
    -- parameters computed from it by 'freeParams'.
    free_params <- freeParams body
    emit $ Imp.Op $ Imp.ParLoop "segmented_hist" body free_params

-- | Generate the body of a segmented histogram; the statements are
-- wrapped in 'collect' and handed back as code.
--
-- The generated code declares the flat index of @space@ as an @int64@
-- and sets it with 'Imp.GetTaskId'.  Inside the loop produced by
-- 'generateChunkLoop' it then runs a sequential loop over the innermost
-- dimension of @space@.  For every iteration it:
--
-- 1. binds the outer indices of @space@ by unflattening the chunk-loop
--    index, and binds the innermost index to the loop counter;
--
-- 2. compiles the statements of @kbody@;
--
-- 3. for every operator in @histops@, copies the map-out results to
--    their pattern elements and, guarded by a bounds check of the bucket
--    against the destination shape, applies the operator lambda to the
--    current destination value and the new value, writing the lambda
--    result back to the destination.
--
-- Pattern elements beyond the first @num_red_res@ elements of @pat@ are
-- treated as map-out results.
--
-- Note that 'init' and 'last' are applied to the dimensions of @space@,
-- so a space without dimensions makes this function fail.
compileSegHistBody ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen Imp.MCCode
compileSegHistBody pat space histops kbody = collect $ do
  let (is, ns) = unzip $ unSegSpace space
      ns_64 = map toInt64Exp ns
      -- All indices except the innermost one, as expressions.  These
      -- prefix every index into a histogram destination below.
      segment_is = map Imp.le64 (init is)

  let num_red_res = length histops + sum (map (length . histNeutral) histops)
      map_pes = drop num_red_res $ patElems pat
      per_red_pes = segHistOpChunks histops $ patElems pat

  dPrim_ (segFlat space) int64
  sOp $ Imp.GetTaskId (segFlat space)

  generateChunkLoop "SegHist" $ \idx -> do
    let inner_bound = last ns_64
    sFor "i" inner_bound $ \i -> do
      -- The chunk-loop index enumerates the outer dimensions; the
      -- innermost dimension is traversed by the loop counter @i@.
      zipWithM_ dPrimV_ (init is) $ unflattenIndex (init ns_64) idx
      dPrimV_ (last is) i

      compileStms mempty (kernelBodyStms kbody) $ do
        -- The trailing @length map_pes@ kernel results are the map-out
        -- results; everything before them belongs to the operators.
        let (red_res, map_res) =
              splitFromEnd (length map_pes) $
                map kernelResultSubExp $
                  kernelBodyResult kbody
        forM_ (zip3 per_red_pes histops (splitHistResults histops red_res)) $
          \(red_pes, HistOp dest_shape _ _ _ shape lam, (bucket, vs')) -> do
            let lam_params = lambdaParams lam
                lam_body = lambdaBody lam
                -- The first @length vs'@ lambda parameters receive the
                -- current destination values, the rest the new values.
                (is_params, vs_params) = splitAt (length vs') lam_params
                bucket' = map toInt64Exp bucket
                dest_shape' = map toInt64Exp $ shapeDims dest_shape
                bucket_in_bounds =
                  inBounds (Slice (map DimFix bucket')) dest_shape'
                -- Full index into a destination: outer indices, then the
                -- bucket, then the position within the operator shape.
                dest_is vec_is = segment_is ++ bucket' ++ vec_is

            -- NOTE(review): nothing in this copy depends on the current
            -- operator, so it is emitted once per element of @histops@.
            -- It is kept here so that the emitted code is unchanged.
            sComment "save map-out results" $
              forM_ (zip map_pes map_res) $ \(pe, res) ->
                copyDWIMFix (patElemName pe) (map Imp.le64 is) res []

            sComment "perform updates" $
              sWhen bucket_in_bounds $ do
                dLParams lam_params
                sLoopNest shape $ \vec_is -> do
                  -- Index
                  forM_ (zip red_pes is_params) $ \(pe, p) ->
                    copyDWIMFix
                      (paramName p)
                      []
                      (Var $ patElemName pe)
                      (dest_is vec_is)
                  -- Value at index
                  forM_ (zip vs_params vs') $ \(p, v) ->
                    copyDWIMFix (paramName p) [] v vec_is
                  -- Run the operator and store its results back at the
                  -- same destination index that was read above.
                  compileStms mempty (bodyStms lam_body) $
                    forM_ (zip red_pes $ map resSubExp $ bodyResult lam_body) $
                      \(pe, se) ->
                        copyDWIMFix
                          (patElemName pe)
                          (dest_is vec_is)
                          se
                          []