{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
-- | Our compilation strategy for 'SegHist' is based around avoiding
-- bin conflicts.  We do this by splitting the input into chunks, and
-- for each chunk computing a single subhistogram.  Then we combine
-- the subhistograms using an ordinary segmented reduction ('SegRed').
--
-- There are some branches around to efficiently handle the case where
-- we use only a single subhistogram (because it's large), so that we
-- respect the asymptotics, and do not copy the destination array.
--
-- We also use a heuristic strategy for computing subhistograms in
-- local memory when possible.  Given:
--
-- H: total size of histograms in bytes, including any lock arrays.
--
-- G: group size
--
-- T: number of bytes of local memory each thread can be given without
-- impacting occupancy (determined experimentally, e.g. 32).
--
-- LMAX: maximum amount of local memory per workgroup (hard limit).
--
-- We wish to compute:
--
-- COOP: cooperation level (number of threads per subhistogram)
--
-- LH: number of local memory subhistograms
--
-- We do this as:
--
-- COOP = ceil(H / T)
-- LH = ceil((G*T)/H)
-- if COOP <= G && H <= LMAX then
--   use local memory
-- else
--   use global memory

module Futhark.CodeGen.ImpGen.Kernels.SegHist
  ( compileSegHist )
  where

import Control.Monad.Except
import Data.Maybe
import Data.List (foldl', genericLength, zip4, zip6)

import Prelude hiding (quot, rem)

import Futhark.MonadFreshNames
import Futhark.Representation.KernelsMem
import qualified Futhark.Representation.Mem.IxFun as IxFun
import Futhark.Pass.ExplicitAllocations()
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.SegRed (compileSegRed')
import Futhark.CodeGen.ImpGen.Kernels.Base
import Futhark.Util.IntegralExp (quotRoundingUp, quot, rem)
import Futhark.Util (chunks, mapAccumLM, splitFromEnd, takeLast)
import Futhark.Construct (fullSliceNum)

-- | Sign-extend a 32-bit integer expression to a 64-bit integer
-- expression.
i32Toi64 :: PrimExp v -> PrimExp v
i32Toi64 = ConvOpExp (SExt Int32 Int64)

-- | Per-destination bookkeeping for the subhistograms of one
-- histogram array.
data SubhistosInfo = SubhistosInfo { subhistosArray :: VName
                                     -- ^ Name of the array holding the
                                     -- subhistograms.
                                   , subhistosAlloc :: CallKernelGen ()
                                     -- ^ Host-side action that sets up
                                     -- the memory behind 'subhistosArray'.
                                   }

-- | Everything the code generator tracks for a single histogram
-- operator of a 'SegHist'.
data SegHistSlug = SegHistSlug
                   { slugOp :: HistOp KernelsMem
                     -- ^ The histogram operator itself.
                   , slugNumSubhistos :: VName
                     -- ^ Variable holding the number of subhistograms.
                   , slugSubhistos :: [SubhistosInfo]
                     -- ^ One entry per destination array of the operator.
                   , slugAtomicUpdate :: AtomicUpdate KernelsMem KernelEnv
                     -- ^ How to update a bin atomically for this operator.
                   }

-- | The number of bytes occupied by a single (sub)histogram of the
-- given operator: for every value returned by the operator lambda,
-- the size of that type wrapped in the operator shape ('histShape')
-- and then in 'histWidth' rows, summed over all returned values.
histoSpaceUsage :: HistOp KernelsMem
                -> Imp.Count Imp.Bytes Imp.Exp
histoSpaceUsage op =
  sum $
  map (typeSize .
       (`arrayOfRow` histWidth op) .
       (`arrayOfShape` histShape op)) $
  lambdaReturnType $ histOp op

-- | Figure out how much memory is needed per histogram, both
-- segmented and unsegmented, and compute some other auxiliary
-- information.
--
-- Returns the size in bytes of one histogram, that size multiplied by
-- the sizes of all segment dimensions, and the 'SegHistSlug'
-- describing the operator.
computeHistoUsage :: SegSpace
                  -> HistOp KernelsMem
                  -> CallKernelGen (Imp.Count Imp.Bytes Imp.Exp,
                                    Imp.Count Imp.Bytes Imp.Exp,
                                    SegHistSlug)
computeHistoUsage space op = do
  -- All dimensions of the space except the last one are segment
  -- dimensions.
  let segment_dims = init $ unSegSpace space
      num_segments = length segment_dims

  -- Create names for the intermediate array memory blocks,
  -- memory block sizes, arrays, and number of subhistograms.
  num_subhistos <- dPrim "num_subhistos" int32
  subhisto_infos <- forM (zip (histDest op) (histNeutral op)) $ \(dest, ne) -> do
    dest_t <- lookupType dest
    dest_mem <- entryArrayLocation <$> lookupArray dest

    subhistos_mem <-
      sDeclareMem (baseString dest ++ "_subhistos_mem") (Space "device")

    -- Shape of the subhistogram array: the segment dimensions, then
    -- the number of subhistograms, then the non-segment dimensions of
    -- the destination.
    let subhistos_shape = Shape (map snd segment_dims ++ [Var num_subhistos]) <>
                          stripDims num_segments (arrayShape dest_t)
        subhistos_membind = ArrayIn subhistos_mem $ IxFun.iota $
                            map (primExpFromSubExp int32) $ shapeDims subhistos_shape
    subhistos <- sArray (baseString dest ++ "_subhistos")
                 (elemType dest_t) subhistos_shape subhistos_membind

    return $ SubhistosInfo subhistos $ do
      -- With a single subhistogram we make the subhistogram memory
      -- refer to the destination memory instead of allocating.
      let unitHistoCase =
            emit $
            Imp.SetMem subhistos_mem (memLocationName dest_mem) $
            Space "device"

          -- Otherwise allocate room for all subhistograms, fill them
          -- with the neutral element, and write the destination into
          -- subhistogram 0 of every segment.
          multiHistoCase = do
            let num_elems = foldl' (*) (Imp.var num_subhistos int32) $
                            map (toExp' int32) $ arrayDims dest_t

            let subhistos_mem_size =
                  Imp.bytes $
                  Imp.unCount (Imp.elements num_elems `Imp.withElemType` elemType dest_t)

            sAlloc_ subhistos_mem subhistos_mem_size $ Space "device"
            sReplicate subhistos ne
            subhistos_t <- lookupType subhistos
            let slice = fullSliceNum (map (toExp' int32) $ arrayDims subhistos_t) $
                        map (unitSlice 0 . toExp' int32 . snd) segment_dims ++
                        [DimFix 0]
            sUpdate subhistos slice $ Var dest

      sIf (Imp.var num_subhistos int32 .==. 1) unitHistoCase multiHistoCase

  let h = histoSpaceUsage op
      segmented_h = h * product (map (Imp.bytes . toExp' int32) $ init $ segSpaceDims space)

  atomics <- hostAtomics <$> askEnv

  return (h,
          segmented_h,
          SegHistSlug op num_subhistos subhisto_infos $
          atomicUpdateLocking atomics $ histOp op)

-- | Produce the function that performs an atomic update of the given
-- destination arrays in global memory for one histogram operator.
--
-- The first argument is the lock array set up so far, if any.  It is
-- returned unchanged unless the operator needs locking and no lock
-- array exists yet, in which case a new one is created and returned.
prepareAtomicUpdateGlobal :: Maybe Locking -> [VName] -> SegHistSlug
                          -> CallKernelGen (Maybe Locking,
                                            [Imp.Exp] -> InKernelGen ())
prepareAtomicUpdateGlobal l dests slug =
  -- We need a separate lock array if the operators are not all of a
  -- particularly simple form that permits pure atomic operations.
  case (l, slugAtomicUpdate slug) of
    (_, AtomicPrim f) -> return (l, f (Space "global") dests)
    (_, AtomicCAS f) -> return (l, f (Space "global") dests)
    (Just l', AtomicLocking f) -> return (l, f l' (Space "global") dests)
    (Nothing, AtomicLocking f) -> do
      -- The number of locks used here is too low, but since we are
      -- currently forced to inline a huge list, I'm keeping it down
      -- for now.  Some quick experiments suggested that it has little
      -- impact anyway (maybe the locking case is just too slow).
      --
      -- A fun solution would also be to use a simple hashing
      -- algorithm to ensure good distribution of locks.
      let num_locks = 100151
          dims = map (toExp' int32) $
                 shapeDims (histShape (slugOp slug)) ++
                 [ Var (slugNumSubhistos slug)
                 , histWidth (slugOp slug)]
      locks <-
        sStaticArray "hist_locks" (Space "device") int32 $
        Imp.ArrayZeros num_locks
      -- The lock for a bin is found by flattening its index and
      -- taking the remainder modulo the number of locks.
      let l' = Locking locks 0 1 0 (pure . (`rem` fromIntegral num_locks) . flattenIndex dims)
      return (Just l', f l' (Space "global") dests)

-- | Some kernel bodies are not safe (or efficient) to execute
-- multiple times.
data Passage = MustBeSinglePass | MayBeMultiPass deriving (Eq, Ord)

-- | Classify a kernel body: it may be executed in multiple passes
-- only if alias analysis shows that it consumes nothing.
bodyPassage :: KernelBody KernelsMem -> Passage
bodyPassage kbody
  | mempty == consumedInKernelBody (aliasAnalyseKernelBody kbody) =
      MayBeMultiPass
  | otherwise =
      MustBeSinglePass

prepareIntermediateArraysGlobal :: Passage -> Imp.Exp -> Imp.Exp -> [SegHistSlug]
                                -> CallKernelGen
                                   (Imp.Exp,
                                    [[Imp.Exp] -> InKernelGen ()])
prepareIntermediateArraysGlobal :: Passage
-> Exp
-> Exp
-> [SegHistSlug]
-> CallKernelGen (Exp, [[Exp] -> InKernelGen ()])
prepareIntermediateArraysGlobal Passage
passage Exp
hist_T Exp
hist_N [SegHistSlug]
slugs = do
  -- The paper formulae assume there is only one histogram, but in our
  -- implementation there can be multiple that have been horizontally
  -- fused.  We do a bit of trickery with summings and averages to
  -- pretend there is really only one.  For the case of a single
  -- histogram, the actual calculations should be the same as in the
  -- paper.

  -- The sum of all Hs.
  Exp
hist_H <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_H" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> ([Exp] -> Exp) -> [Exp] -> ImpM KernelsMem HostEnv HostOp Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Exp] -> ImpM KernelsMem HostEnv HostOp Exp)
-> ImpM KernelsMem HostEnv HostOp [Exp]
-> ImpM KernelsMem HostEnv HostOp Exp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (SegHistSlug -> ImpM KernelsMem HostEnv HostOp Exp)
-> [SegHistSlug] -> ImpM KernelsMem HostEnv HostOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp (SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
-> (SegHistSlug -> SubExp)
-> SegHistSlug
-> ImpM KernelsMem HostEnv HostOp Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth (HistOp KernelsMem -> SubExp)
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs

  Exp
hist_RF <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_RF" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
    [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((SegHistSlug -> Exp) -> [SegHistSlug] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 (Exp -> Exp) -> (SegHistSlug -> Exp) -> SegHistSlug -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp) -> (SegHistSlug -> SubExp) -> SegHistSlug -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histRaceFactor (HistOp KernelsMem -> SubExp)
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs)
    Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/ Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 ([SegHistSlug] -> Exp
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs)

  Exp
hist_el_size <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_el_size" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Exp) -> [SegHistSlug] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> Exp
slugElAvgSize [SegHistSlug]
slugs

  Exp
hist_C_max <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_C_max" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
    BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (FloatType -> BinOp
FMin FloatType
Float64) (Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 Exp
hist_T) (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 Exp
hist_H Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/ Exp
hist_k_ct_min

  Exp
hist_M_min <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_M_min" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
    BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMax IntType
Int32) Exp
1 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ Exp -> Exp
forall v. PrimExp v -> PrimExp v
t64 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 Exp
hist_T Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/ Exp
hist_C_max

  -- Querying L2 cache size is not reliable.  Instead we provide a
  -- tunable knob with a hopefully sane default.
  let hist_L2_def :: Int32
hist_L2_def = Int32
4 Int32 -> Int32 -> Int32
forall a. Num a => a -> a -> a
* Int32
1024 Int32 -> Int32 -> Int32
forall a. Num a => a -> a -> a
* Int32
1024
  VName
hist_L2 <- String -> PrimType -> ImpM KernelsMem HostEnv HostOp VName
forall lore r op. String -> PrimType -> ImpM lore r op VName
dPrim String
"L2_size" PrimType
int32
  Maybe Name
entry <- ImpM KernelsMem HostEnv HostOp (Maybe Name)
forall lore r op. ImpM lore r op (Maybe Name)
askFunction
  -- Equivalent to F_L2*L2 in paper.
  HostOp -> CallKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> Name -> SizeClass -> HostOp
Imp.GetSize VName
hist_L2
    (Maybe Name -> Name -> Name
keyWithEntryPoint Maybe Name
entry (Name -> Name) -> Name -> Name
forall a b. (a -> b) -> a -> b
$ String -> Name
nameFromString (VName -> String
forall a. Pretty a => a -> String
pretty VName
hist_L2)) (SizeClass -> HostOp) -> SizeClass -> HostOp
forall a b. (a -> b) -> a -> b
$
    Name -> Int32 -> SizeClass
Imp.SizeBespoke (String -> Name
nameFromString String
"L2_for_histogram") Int32
hist_L2_def

  let hist_L2_ln_sz :: Exp
hist_L2_ln_sz = Exp
16Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
*Exp
4 -- L2 cache line size approximation

  Exp
hist_RACE_exp <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_RACE_exp" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
    BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (FloatType -> BinOp
FMax FloatType
Float64) Exp
1 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
    (Exp
hist_k_RF Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_RF) Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/
    (Exp
hist_L2_ln_sz Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/ Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 Exp
hist_el_size)

  VName
hist_S <- String -> PrimType -> ImpM KernelsMem HostEnv HostOp VName
forall lore r op. String -> PrimType -> ImpM lore r op VName
dPrim String
"hist_S" PrimType
int32

  -- For sparse histograms (H exceeds N) we only want a single chunk.
  Exp -> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall lore r op.
Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (Exp
hist_N Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
hist_H)
    (VName
hist_S VName -> Exp -> CallKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<-- Exp
1) (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
    VName
hist_S VName -> Exp -> CallKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<--
    case Passage
passage of
      Passage
MayBeMultiPass ->
        (Exp
hist_M_min Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_H Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_el_size) Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp`
        Exp -> Exp
forall v. PrimExp v -> PrimExp v
t64 (Exp
hist_F_L2 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 (VName -> Exp
Imp.vi32 VName
hist_L2) Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_RACE_exp)
      Passage
MustBeSinglePass ->
        Exp
1

  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Race expansion factor (RACE^exp)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_RACE_exp
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of chunks (S)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ VName -> Exp
Imp.vi32 VName
hist_S

  [[Exp] -> InKernelGen ()]
histograms <- (Maybe Locking, [[Exp] -> InKernelGen ()])
-> [[Exp] -> InKernelGen ()]
forall a b. (a, b) -> b
snd ((Maybe Locking, [[Exp] -> InKernelGen ()])
 -> [[Exp] -> InKernelGen ()])
-> ImpM
     KernelsMem
     HostEnv
     HostOp
     (Maybe Locking, [[Exp] -> InKernelGen ()])
-> ImpM KernelsMem HostEnv HostOp [[Exp] -> InKernelGen ()]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Maybe Locking
 -> SegHistSlug
 -> CallKernelGen (Maybe Locking, [Exp] -> InKernelGen ()))
-> Maybe Locking
-> [SegHistSlug]
-> ImpM
     KernelsMem
     HostEnv
     HostOp
     (Maybe Locking, [[Exp] -> InKernelGen ()])
forall (m :: * -> *) acc x y.
Monad m =>
(acc -> x -> m (acc, y)) -> acc -> [x] -> m (acc, [y])
mapAccumLM (Exp
-> Exp
-> Exp
-> Exp
-> Maybe Locking
-> SegHistSlug
-> CallKernelGen (Maybe Locking, [Exp] -> InKernelGen ())
onOp (VName -> Exp
Imp.vi32 VName
hist_L2) Exp
hist_M_min (VName -> Exp
Imp.vi32 VName
hist_S) Exp
hist_RACE_exp) Maybe Locking
forall a. Maybe a
Nothing [SegHistSlug]
slugs

  (Exp, [[Exp] -> InKernelGen ()])
-> CallKernelGen (Exp, [[Exp] -> InKernelGen ()])
forall (m :: * -> *) a. Monad m => a -> m a
return (VName -> Exp
Imp.vi32 VName
hist_S, [[Exp] -> InKernelGen ()]
histograms)
  where
    -- Tunable constants of the cost model.  All three values were
    -- chosen experimentally.

    -- Not referenced in the part of the enclosing function that
    -- follows; presumably used in the earlier computation of the
    -- minimum multiplication degree -- TODO confirm.
    hist_k_ct_min :: Imp.Exp
    hist_k_ct_min = 2

    -- Scaling applied to the race factor (hist_RF) when computing the
    -- race expansion factor (hist_RACE_exp).
    hist_k_RF :: Imp.Exp
    hist_k_RF = 0.75

    -- Scaling applied to the L2 cache size (hist_L2) when computing
    -- the number of chunks and hist_k_max.
    hist_F_L2 :: Imp.Exp
    hist_F_L2 = 0.4

    -- Convert a signed 32-bit integer expression to a 64-bit
    -- floating-point expression.
    r64 :: PrimExp v -> PrimExp v
    r64 e = ConvOpExp (SIToFP Int32 Float64) e
    -- Convert a 64-bit floating-point expression to a signed 32-bit
    -- integer expression.
    t64 :: PrimExp v -> PrimExp v
    t64 e = ConvOpExp (FPToSI Float64 Int32) e

    -- "Average element size": the total element size of the operator
    -- (see 'slugElSize') divided by the number of components it is
    -- made up of.  When the operator is updated by locking, one extra
    -- component is counted, matching the extra int32 that
    -- 'slugElSize' includes in that case.
    slugElAvgSize :: SegHistSlug -> Imp.Exp
    slugElAvgSize slug@(SegHistSlug op _ _ do_op) =
      slugElSize slug `quot` num_components
      where num_results :: Imp.Exp
            num_results = genericLength $ lambdaReturnType $ histOp op

            num_components :: Imp.Exp
            num_components =
              case do_op of
                AtomicLocking{} -> 1 + num_results
                _               -> num_results

    -- Total element size of the operator, in bytes: the sum of
    -- 'typeSize' over every result type of the operator lambda, each
    -- first turned into an array of shape @histShape op@.
    --
    -- When the operator is updated by locking, an additional int32
    -- (arrayed by the same shape) is included in the sum; presumably
    -- this accounts for the lock.
    --
    -- Note that this is a total, not an average; see 'slugElAvgSize'
    -- for the per-component average.
    slugElSize :: SegHistSlug -> Imp.Exp
    slugElSize (SegHistSlug op _ _ do_op) =
      unCount $ sum $ map (typeSize . (`arrayOfShape` histShape op)) el_types
      where result_types = lambdaReturnType $ histOp op

            -- The types whose sizes are summed.  Only the locking
            -- case adds the extra int32 in front.
            el_types =
              case do_op of
                AtomicLocking{} -> Prim int32 : result_types
                _               -> result_types

    -- Set up the global-memory subhistograms for a single histogram
    -- operator, and produce the function used to update it from
    -- inside the kernel.
    --
    -- Arguments, in order: the L2 cache size, the minimum
    -- multiplication degree, the number of chunks, the race expansion
    -- factor, the locking state threaded through by 'mapAccumLM', and
    -- the operator itself.  The enclosing function's @hist_T@ and
    -- @hist_N@ are also used.
    --
    -- Returns the updated locking state together with the in-kernel
    -- update function, exactly as produced by
    -- 'prepareAtomicUpdateGlobal'.
    onOp :: Imp.Exp -> Imp.Exp -> Imp.Exp -> Imp.Exp
         -> Maybe Locking -> SegHistSlug
         -> CallKernelGen (Maybe Locking, [Imp.Exp] -> InKernelGen ())
    onOp hist_L2 hist_M_min hist_S hist_RACE_exp l slug = do
      let SegHistSlug op num_subhistos subhisto_info do_op = slug
      hist_H <- toExp $ histWidth op

      -- Width of the part of the histogram handled in one chunk.
      hist_H_chk <- dPrimVE "hist_H_chk" $
                    hist_H `quotRoundingUp` hist_S

      emit $ Imp.DebugPrint "Chunk size (H_chk)" $ Just hist_H_chk

      -- min(F_L2 * (L2 / element size) * RACE^exp, N) / T
      hist_k_max <- dPrimVE "hist_k_max" $
        Imp.BinOpExp (FMin Float64)
        (hist_F_L2 * (r64 hist_L2 / r64 (slugElSize slug)) * hist_RACE_exp)
        (r64 hist_N)
        / r64 hist_T

      hist_u <- dPrimVE "hist_u" $
                case do_op of
                  AtomicPrim{} -> 2
                  _            -> 1

      -- min(T, u * H_chk / k_max)
      hist_C <- dPrimVE "hist_C" $
                Imp.BinOpExp (FMin Float64) (r64 hist_T) $
                r64 (hist_u * hist_H_chk) / hist_k_max

      -- Number of subhistograms per result histogram.  Operators
      -- updated with primitive atomics always use just one; otherwise
      -- it is max(M_min, T / C).
      hist_M <- dPrimVE "hist_M" $
        case do_op of
          AtomicPrim{} -> 1
          _ -> Imp.BinOpExp (SMax Int32) hist_M_min $
               t64 $ r64 hist_T / hist_C

      emit $ Imp.DebugPrint "Elements/thread in L2 cache (k_max)" $ Just hist_k_max
      emit $ Imp.DebugPrint "Multiplication degree (M)" $ Just hist_M
      emit $ Imp.DebugPrint "Cooperation level (C)" $ Just hist_C

      -- num_subhistos is the variable we use to communicate back.
      num_subhistos <-- hist_M

      -- Initialise sub-histograms.
      --
      -- If hist_M is 1, then we just reuse the original
      -- destination.  The idea is to avoid a copy if we are writing a
      -- small number of values into a very large prior histogram.
      dests <- forM (zip (histDest op) subhisto_info) $ \(dest, info) -> do
        dest_mem <- entryArrayLocation <$> lookupArray dest

        sub_mem <- memLocationName . entryArrayLocation <$>
                   lookupArray (subhistosArray info)

        let -- Point the subhistogram memory at the destination's
            -- memory instead of allocating.
            unitHistoCase =
              emit $
              Imp.SetMem sub_mem (memLocationName dest_mem) $
              Space "device"

            multiHistoCase = subhistosAlloc info

        sIf (hist_M .==. 1) unitHistoCase multiHistoCase

        return $ subhistosArray info

      prepareAtomicUpdateGlobal l dests slug

histKernelGlobalPass :: [PatElem KernelsMem]
                     -> Count NumGroups Imp.Exp
                     -> Count GroupSize Imp.Exp
                     -> SegSpace
                     -> [SegHistSlug]
                     -> KernelBody KernelsMem
                     -> [[Imp.Exp] -> InKernelGen ()]
                     -> Imp.Exp -> Imp.Exp
                     -> CallKernelGen ()
histKernelGlobalPass :: [PatElem KernelsMem]
-> Count NumGroups Exp
-> Count GroupSize Exp
-> SegSpace
-> [SegHistSlug]
-> KernelBody KernelsMem
-> [[Exp] -> InKernelGen ()]
-> Exp
-> Exp
-> CallKernelGen ()
histKernelGlobalPass [PatElem KernelsMem]
map_pes Count NumGroups Exp
num_groups Count GroupSize Exp
group_size SegSpace
space [SegHistSlug]
slugs KernelBody KernelsMem
kbody [[Exp] -> InKernelGen ()]
histograms Exp
hist_S Exp
chk_i = do

  let ([VName]
space_is, [SubExp]
space_sizes) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      space_sizes_64 :: [Exp]
space_sizes_64 = (SubExp -> Exp) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (Exp -> Exp
forall v. PrimExp v -> PrimExp v
i32Toi64 (Exp -> Exp) -> (SubExp -> Exp) -> SubExp -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) [SubExp]
space_sizes
      total_w_64 :: Exp
total_w_64 = [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Exp]
space_sizes_64

  [Exp]
hist_H_chks <- [SubExp]
-> (SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
-> ImpM KernelsMem HostEnv HostOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ((SegHistSlug -> SubExp) -> [SegHistSlug] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth (HistOp KernelsMem -> SubExp)
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs) ((SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
 -> ImpM KernelsMem HostEnv HostOp [Exp])
-> (SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
-> ImpM KernelsMem HostEnv HostOp [Exp]
forall a b. (a -> b) -> a -> b
$ \SubExp
w -> do
    Exp
w' <- SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
w
    String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_H_chk" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
w' Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp` Exp
hist_S

  String
-> Count NumGroups Exp
-> Count GroupSize Exp
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"seghist_global" Count NumGroups Exp
num_groups Count GroupSize Exp
group_size (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv

    -- Compute subhistogram index for each thread, per histogram.
    [Exp]
subhisto_inds <- [SegHistSlug]
-> (SegHistSlug -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> ImpM KernelsMem KernelEnv KernelOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [SegHistSlug]
slugs ((SegHistSlug -> ImpM KernelsMem KernelEnv KernelOp Exp)
 -> ImpM KernelsMem KernelEnv KernelOp [Exp])
-> (SegHistSlug -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> ImpM KernelsMem KernelEnv KernelOp [Exp]
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug ->
      String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"subhisto_ind" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
      KernelConstants -> Exp
kernelGlobalThreadId KernelConstants
constants Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot`
      (KernelConstants -> Exp
kernelNumThreads KernelConstants
constants Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp` VName -> Exp
Imp.vi32 (SegHistSlug -> VName
slugNumSubhistos SegHistSlug
slug))

    -- Loop over flat offsets into the input and output.  The
    -- calculation is done with 64-bit integers to avoid overflow,
    -- but the final unflattened segment indexes are 32 bit.
    let gtid :: Exp
gtid = Exp -> Exp
forall v. PrimExp v -> PrimExp v
i32Toi64 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ KernelConstants -> Exp
kernelGlobalThreadId KernelConstants
constants
        num_threads :: Exp
num_threads = Exp -> Exp
forall v. PrimExp v -> PrimExp v
i32Toi64 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ KernelConstants -> Exp
kernelNumThreads KernelConstants
constants
    Exp -> Exp -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
kernelLoop Exp
gtid Exp
num_threads Exp
total_w_64 ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
offset -> do

      -- Construct segment indices.
      let setIndex :: VName -> Exp -> ImpM lore r op ()
setIndex VName
v Exp
e = do VName -> PrimType -> ImpM lore r op ()
forall lore r op. VName -> PrimType -> ImpM lore r op ()
dPrim_ VName
v PrimType
int32
                            VName
v VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<-- Exp
e
      (VName -> Exp -> InKernelGen ())
-> [VName] -> [Exp] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
setIndex [VName]
space_is ([Exp] -> InKernelGen ()) -> [Exp] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        (Exp -> Exp) -> [Exp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (ConvOp -> Exp -> Exp
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> IntType -> ConvOp
SExt IntType
Int64 IntType
Int32)) ([Exp] -> [Exp]) -> [Exp] -> [Exp]
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [Exp]
space_sizes_64 Exp
offset

      -- We execute the bucket function once and update each histogram
      -- serially.  The kernel body (and hence the bucket function) is
      -- only run when the flat offset is less than the total size of
      -- the iteration space.  This also involves writing to the mapout
      -- arrays.
      let input_in_bounds :: Exp
input_in_bounds = Exp
offset Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
total_w_64

      Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
input_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Names -> Stms KernelsMem -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody KernelsMem -> Stms KernelsMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody KernelsMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
        let ([KernelResult]
red_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElemT LetAttrMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem KernelsMem]
[PatElemT LetAttrMem]
map_pes) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody KernelsMem -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody KernelsMem
kbody

        String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"save map-out results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          [(PatElemT LetAttrMem, KernelResult)]
-> ((PatElemT LetAttrMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetAttrMem]
-> [KernelResult] -> [(PatElemT LetAttrMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem KernelsMem]
[PatElemT LetAttrMem]
map_pes [KernelResult]
map_res) (((PatElemT LetAttrMem, KernelResult) -> InKernelGen ())
 -> InKernelGen ())
-> ((PatElemT LetAttrMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetAttrMem
pe, KernelResult
res) ->
          VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT LetAttrMem -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT LetAttrMem
pe)
          (((VName, SubExp) -> Exp) -> [(VName, SubExp)] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> Exp
Imp.vi32 (VName -> Exp)
-> ((VName, SubExp) -> VName) -> (VName, SubExp) -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst) ([(VName, SubExp)] -> [Exp]) -> [(VName, SubExp)] -> [Exp]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space)
          (KernelResult -> SubExp
kernelResultSubExp KernelResult
res) []

        let ([KernelResult]
buckets, [KernelResult]
vs) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegHistSlug] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SegHistSlug]
slugs) [KernelResult]
red_res
            perOp :: [KernelResult] -> [[KernelResult]]
perOp = [Int] -> [KernelResult] -> [[KernelResult]]
forall a. [Int] -> [a] -> [[a]]
chunks ([Int] -> [KernelResult] -> [[KernelResult]])
-> [Int] -> [KernelResult] -> [[KernelResult]]
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Int) -> [SegHistSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int) -> (SegHistSlug -> [VName]) -> SegHistSlug -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> [VName]
forall lore. HistOp lore -> [VName]
histDest (HistOp KernelsMem -> [VName])
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs

        String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          [(HistOp KernelsMem, [Exp] -> InKernelGen (), KernelResult,
  [KernelResult], Exp, Exp)]
-> ((HistOp KernelsMem, [Exp] -> InKernelGen (), KernelResult,
     [KernelResult], Exp, Exp)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp KernelsMem]
-> [[Exp] -> InKernelGen ()]
-> [KernelResult]
-> [[KernelResult]]
-> [Exp]
-> [Exp]
-> [(HistOp KernelsMem, [Exp] -> InKernelGen (), KernelResult,
     [KernelResult], Exp, Exp)]
forall a b c d e f.
[a] -> [b] -> [c] -> [d] -> [e] -> [f] -> [(a, b, c, d, e, f)]
zip6 ((SegHistSlug -> HistOp KernelsMem)
-> [SegHistSlug] -> [HistOp KernelsMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp KernelsMem
slugOp [SegHistSlug]
slugs) [[Exp] -> InKernelGen ()]
histograms [KernelResult]
buckets ([KernelResult] -> [[KernelResult]]
perOp [KernelResult]
vs) [Exp]
subhisto_inds [Exp]
hist_H_chks) (((HistOp KernelsMem, [Exp] -> InKernelGen (), KernelResult,
   [KernelResult], Exp, Exp)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((HistOp KernelsMem, [Exp] -> InKernelGen (), KernelResult,
     [KernelResult], Exp, Exp)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          \(HistOp SubExp
dest_w SubExp
_ [VName]
_ [SubExp]
_ Shape
shape LambdaT KernelsMem
lam,
            [Exp] -> InKernelGen ()
do_op, KernelResult
bucket, [KernelResult]
vs', Exp
subhisto_ind, Exp
hist_H_chk) -> do

            let chk_beg :: Exp
chk_beg = Exp
chk_i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_H_chk
                bucket' :: Exp
bucket' = PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp) -> SubExp -> Exp
forall a b. (a -> b) -> a -> b
$ KernelResult -> SubExp
kernelResultSubExp KernelResult
bucket
                dest_w' :: Exp
dest_w' = PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 SubExp
dest_w
                bucket_in_bounds :: Exp
bucket_in_bounds = Exp
chk_beg Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<=. Exp
bucket' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&.
                                   Exp
bucket' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. (Exp
chk_beg Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
hist_H_chk) Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&.
                                   Exp
bucket' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
dest_w'
                vs_params :: [Param LetAttrMem]
vs_params = Int -> [Param LetAttrMem] -> [Param LetAttrMem]
forall a. Int -> [a] -> [a]
takeLast ([KernelResult] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
vs') ([Param LetAttrMem] -> [Param LetAttrMem])
-> [Param LetAttrMem] -> [Param LetAttrMem]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT KernelsMem
lam

            Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
bucket_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
              let bucket_is :: [Exp]
bucket_is = (VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Exp
Imp.vi32 ([VName] -> [VName]
forall a. [a] -> [a]
init [VName]
space_is) [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++
                              [Exp
subhisto_ind, Exp
bucket']
              [LParam KernelsMem] -> InKernelGen ()
forall lore r op. Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams ([LParam KernelsMem] -> InKernelGen ())
-> [LParam KernelsMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT KernelsMem
lam
              Shape -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([Exp] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest Shape
shape (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
is -> do
                [(Param LetAttrMem, KernelResult)]
-> ((Param LetAttrMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetAttrMem]
-> [KernelResult] -> [(Param LetAttrMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetAttrMem]
vs_params [KernelResult]
vs') (((Param LetAttrMem, KernelResult) -> InKernelGen ())
 -> InKernelGen ())
-> ((Param LetAttrMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetAttrMem
p, KernelResult
res) ->
                  VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param LetAttrMem -> VName
forall attr. Param attr -> VName
paramName Param LetAttrMem
p) [] (KernelResult -> SubExp
kernelResultSubExp KernelResult
res) [Exp]
is
                [Exp] -> InKernelGen ()
do_op ([Exp]
bucket_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp]
is)


-- | Generate host code for computing the histograms in global memory.
--
-- The group count and group size are first converted to expressions,
-- and their product is used as the number of threads.  After emitting
-- a debug marker, 'prepareIntermediateArraysGlobal' is asked for the
-- number of passes (@hist_S@) and one update operation per histogram.
-- Finally a host-level loop runs 'histKernelGlobalPass' once for each
-- pass index @chk_i@ in @[0, hist_S)@.
--
-- The innermost dimension of the 'SegSpace' is obtained with 'last',
-- so the space must have at least one dimension.
histKernelGlobal :: [PatElem KernelsMem]
                 -> Count NumGroups SubExp -> Count GroupSize SubExp
                 -> SegSpace
                 -> [SegHistSlug]
                 -> KernelBody KernelsMem
                 -> CallKernelGen ()
histKernelGlobal map_pes num_groups group_size space slugs kbody = do
  num_groups' <- traverse toExp num_groups
  group_size' <- traverse toExp group_size
  let space_sizes = map snd $ unSegSpace space
      num_threads = unCount num_groups' * unCount group_size'

  emit $ Imp.DebugPrint "## Using global memory" Nothing

  (hist_S, histograms) <-
    prepareIntermediateArraysGlobal (bodyPassage kbody)
    num_threads (toExp' int32 $ last space_sizes) slugs

  -- One kernel launch per pass; 'hist_S' is the number of passes.
  sFor "chk_i" hist_S $ \chk_i ->
    histKernelGlobalPass map_pes num_groups' group_size' space slugs kbody
    histograms hist_S chk_i

-- | Per-histogram-operation setup for the local-memory strategy, as
-- produced by 'prepareIntermediateArraysLocal' (one list element per
-- 'SegHistSlug').  The first component names the global-memory
-- subhistogram arrays.  The second component is an in-kernel action
-- that takes the chunk size (called @hist_H_chk@ by its producer),
-- allocates the local-memory subhistograms, and returns their names
-- together with the update operation, which takes a list of index
-- expressions.
type InitLocalHistograms = [([VName],
                              SubExp ->
                              InKernelGen ([VName],
                                            [Imp.Exp] -> InKernelGen ()))]

-- | Prepare the arrays needed by the local-memory histogram strategy.
--
-- For every 'SegHistSlug' this sets the slug's subhistogram count to
-- @groups_per_segment * num_segments@ (where @num_segments@ is the
-- product of all but the innermost 'SegSpace' dimension), allocates
-- the global-memory subhistograms, and returns their names paired
-- with a deferred in-kernel initialiser for the local-memory
-- subhistograms.
--
-- The first argument names the variable holding the number of
-- subhistograms per group; it becomes the outermost dimension of every
-- local array allocated by the initialiser.
--
-- The segment dimensions are obtained with 'init', so the 'SegSpace'
-- must have at least one dimension.
prepareIntermediateArraysLocal :: VName
                               -> Count NumGroups Imp.Exp
                               -> SegSpace -> [SegHistSlug]
                               -> CallKernelGen InitLocalHistograms
prepareIntermediateArraysLocal num_subhistos_per_group groups_per_segment space slugs = do
  num_segments <- dPrimVE "num_segments" $
                  product $ map (toExp' int32 . snd) $ init $ unSegSpace space
  mapM (onOp num_segments) slugs
  where
    onOp num_segments (SegHistSlug op num_subhistos subhisto_info do_op) = do

      num_subhistos <-- unCount groups_per_segment * num_segments

      emit $ Imp.DebugPrint "Number of subhistograms in global memory" $
        Just $ Imp.vi32 num_subhistos

      -- Build a function from the chunk size to the in-kernel action
      -- that yields the update operation.  Only the locking variant
      -- depends on the chunk size, because it must allocate and
      -- initialise a lock array of matching shape.
      mk_op <-
        case do_op of
          AtomicPrim f -> return $ const $ return f
          AtomicCAS f -> return $ const $ return f
          AtomicLocking f -> return $ \hist_H_chk -> do
            let lock_shape =
                  Shape $ Var num_subhistos_per_group :
                  shapeDims (histShape op) ++
                  [hist_H_chk]

            dims <- mapM toExp $ shapeDims lock_shape

            locks <- sAllocArray "locks" int32 lock_shape $ Space "local"

            sComment "All locks start out unlocked" $
              groupCoverSpace dims $ \is ->
              copyDWIMFix locks is (intConst Int32 0) []

            return $ f $ Locking locks 0 1 0 id

      -- Initialise local-memory sub-histograms.  Each one has the
      -- per-group subhistogram count as its outermost dimension,
      -- followed by the histogram's own shape with its outer
      -- dimension replaced by the chunk size.
      let init_local_subhistos hist_H_chk = do
            local_subhistos <-
              forM (histType op) $ \t -> do
                let sub_local_shape =
                      Shape [Var num_subhistos_per_group] <>
                      (arrayShape t `setOuterDim` hist_H_chk)
                sAllocArray "subhistogram_local"
                  (elemType t) sub_local_shape (Space "local")

            do_op' <- mk_op hist_H_chk

            return (local_subhistos, do_op' (Space "local") local_subhistos)

      -- Initialise global-memory sub-histograms.
      glob_subhistos <- forM subhisto_info $ \info -> do
        subhistosAlloc info
        return $ subhistosArray info

      return (glob_subhistos, init_local_subhistos)

histKernelLocalPass :: VName -> Count NumGroups Imp.Exp
                    -> [PatElem KernelsMem]
                    -> Count NumGroups Imp.Exp -> Count GroupSize Imp.Exp
                    -> SegSpace
                    -> [SegHistSlug]
                    -> KernelBody KernelsMem
                    -> InitLocalHistograms -> Imp.Exp -> Imp.Exp
                    -> CallKernelGen ()
histKernelLocalPass :: VName
-> Count NumGroups Exp
-> [PatElem KernelsMem]
-> Count NumGroups Exp
-> Count GroupSize Exp
-> SegSpace
-> [SegHistSlug]
-> KernelBody KernelsMem
-> InitLocalHistograms
-> Exp
-> Exp
-> CallKernelGen ()
histKernelLocalPass VName
num_subhistos_per_group_var Count NumGroups Exp
groups_per_segment [PatElem KernelsMem]
map_pes Count NumGroups Exp
num_groups Count GroupSize Exp
group_size SegSpace
space [SegHistSlug]
slugs KernelBody KernelsMem
kbody
                    InitLocalHistograms
init_histograms Exp
hist_S Exp
chk_i = do
  let ([VName]
space_is, [SubExp]
space_sizes) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      segment_is :: [VName]
segment_is = [VName] -> [VName]
forall a. [a] -> [a]
init [VName]
space_is
      segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. [a] -> [a]
init [SubExp]
space_sizes
      (VName
i_in_segment, SubExp
segment_size) = [(VName, SubExp)] -> (VName, SubExp)
forall a. [a] -> a
last ([(VName, SubExp)] -> (VName, SubExp))
-> [(VName, SubExp)] -> (VName, SubExp)
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      num_subhistos_per_group :: Exp
num_subhistos_per_group = VName -> PrimType -> Exp
Imp.var VName
num_subhistos_per_group_var PrimType
int32

  Exp
segment_size' <- SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
segment_size

  Exp
num_segments <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"num_segments" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
                  [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ (SubExp -> Exp) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) [SubExp]
segment_dims

  [VName]
hist_H_chks <- [SubExp]
-> (SubExp -> ImpM KernelsMem HostEnv HostOp VName)
-> ImpM KernelsMem HostEnv HostOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ((SegHistSlug -> SubExp) -> [SegHistSlug] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth (HistOp KernelsMem -> SubExp)
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs) ((SubExp -> ImpM KernelsMem HostEnv HostOp VName)
 -> ImpM KernelsMem HostEnv HostOp [VName])
-> (SubExp -> ImpM KernelsMem HostEnv HostOp VName)
-> ImpM KernelsMem HostEnv HostOp [VName]
forall a b. (a -> b) -> a -> b
$ \SubExp
w -> do
    Exp
w' <- SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
w
    String -> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"hist_H_chk" (Exp -> ImpM KernelsMem HostEnv HostOp VName)
-> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ Exp
w' Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp` Exp
hist_S

  String
-> Count NumGroups Exp
-> Count GroupSize Exp
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"seghist_local" Count NumGroups Exp
num_groups Count GroupSize Exp
group_size (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
    SegVirt -> Exp -> (VName -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt (Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
num_segments) ((VName -> InKernelGen ()) -> InKernelGen ())
-> (VName -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
group_id_var -> do

    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv

    let group_id :: Exp
group_id = VName -> Exp
Imp.vi32 VName
group_id_var

    Exp
flat_segment_id <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"flat_segment_id" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
group_id Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment
    Exp
gid_in_segment <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"gid_in_segment" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
group_id Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`rem` Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment
    -- This pgtid is kind of a "virtualised physical" gtid - not the
    -- same thing as the gtid used for the SegHist itself.
    Exp
pgtid_in_segment <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"pgtid_in_segment" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
      Exp
gid_in_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupSize KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
    Exp
threads_per_segment <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"threads_per_segment" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
      Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupSize KernelConstants
constants

    -- Set segment indices.
    (VName -> Exp -> InKernelGen ())
-> [VName] -> [Exp] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
dPrimV_ [VName]
segment_is ([Exp] -> InKernelGen ()) -> [Exp] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ((SubExp -> Exp) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) [SubExp]
segment_dims) Exp
flat_segment_id

    [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
histograms <- [(([VName],
   SubExp
   -> ImpM
        KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())),
  VName)]
-> ((([VName],
      SubExp
      -> ImpM
           KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())),
     VName)
    -> ImpM
         KernelsMem
         KernelEnv
         KernelOp
         ([(VName, VName)], VName, [Exp] -> InKernelGen ()))
-> ImpM
     KernelsMem
     KernelEnv
     KernelOp
     [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (InitLocalHistograms
-> [VName]
-> [(([VName],
      SubExp
      -> ImpM
           KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())),
     VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip InitLocalHistograms
init_histograms [VName]
hist_H_chks) (((([VName],
    SubExp
    -> ImpM
         KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())),
   VName)
  -> ImpM
       KernelsMem
       KernelEnv
       KernelOp
       ([(VName, VName)], VName, [Exp] -> InKernelGen ()))
 -> ImpM
      KernelsMem
      KernelEnv
      KernelOp
      [([(VName, VName)], VName, [Exp] -> InKernelGen ())])
-> ((([VName],
      SubExp
      -> ImpM
           KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())),
     VName)
    -> ImpM
         KernelsMem
         KernelEnv
         KernelOp
         ([(VName, VName)], VName, [Exp] -> InKernelGen ()))
-> ImpM
     KernelsMem
     KernelEnv
     KernelOp
     [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
forall a b. (a -> b) -> a -> b
$
                  \(([VName]
glob_subhistos, SubExp
-> ImpM
     KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())
init_local_subhistos), VName
hist_H_chk) -> do
      ([VName]
local_subhistos, [Exp] -> InKernelGen ()
do_op) <- SubExp
-> ImpM
     KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())
init_local_subhistos (SubExp
 -> ImpM
      KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ()))
-> SubExp
-> ImpM
     KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
hist_H_chk
      ([(VName, VName)], VName, [Exp] -> InKernelGen ())
-> ImpM
     KernelsMem
     KernelEnv
     KernelOp
     ([(VName, VName)], VName, [Exp] -> InKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
glob_subhistos [VName]
local_subhistos, VName
hist_H_chk, [Exp] -> InKernelGen ()
do_op)

    -- Find index of local subhistograms updated by this thread.  We
    -- try to ensure, as much as possible, that threads in the same
    -- warp use different subhistograms, to avoid conflicts.
    Exp
thread_local_subhisto_i <-
      String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"thread_local_subhisto_i" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
      KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`rem` Exp
num_subhistos_per_group

    let onSlugs :: (SegHistSlug
 -> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ()
onSlugs SegHistSlug
-> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ()
f = [(SegHistSlug, ([(VName, VName)], VName, [Exp] -> InKernelGen ()))]
-> ((SegHistSlug,
     ([(VName, VName)], VName, [Exp] -> InKernelGen ()))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegHistSlug]
-> [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
-> [(SegHistSlug,
     ([(VName, VName)], VName, [Exp] -> InKernelGen ()))]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegHistSlug]
slugs [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
histograms) (((SegHistSlug, ([(VName, VName)], VName, [Exp] -> InKernelGen ()))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((SegHistSlug,
     ([(VName, VName)], VName, [Exp] -> InKernelGen ()))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegHistSlug
slug, ([(VName, VName)]
dests, VName
hist_H_chk, [Exp] -> InKernelGen ()
_)) -> do
          let histo_dims :: [Exp]
histo_dims = (SubExp -> Exp) -> [SubExp] -> [Exp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) ([SubExp] -> [Exp]) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
hist_H_chk SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
:
                           Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape (SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug))
          Exp
histo_size <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"histo_size" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Exp]
histo_dims
          SegHistSlug
-> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ()
f SegHistSlug
slug [(VName, VName)]
dests (VName -> Exp
Imp.vi32 VName
hist_H_chk) [Exp]
histo_dims Exp
histo_size

    let onAllHistograms :: (VName
 -> VName
 -> HistOp KernelsMem
 -> SubExp
 -> Exp
 -> Exp
 -> [Exp]
 -> [Exp]
 -> InKernelGen ())
-> InKernelGen ()
onAllHistograms VName
-> VName
-> HistOp KernelsMem
-> SubExp
-> Exp
-> Exp
-> [Exp]
-> [Exp]
-> InKernelGen ()
f =
          (SegHistSlug
 -> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
  -> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
 -> InKernelGen ())
-> (SegHistSlug
    -> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests Exp
hist_H_chk [Exp]
histo_dims Exp
histo_size -> do
            let group_hists_size :: Exp
group_hists_size = Exp
num_subhistos_per_group Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
histo_size
            Exp
init_per_thread <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"init_per_thread" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
                               Exp
group_hists_size Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp`
                               KernelConstants -> Exp
kernelGroupSize KernelConstants
constants

            [((VName, VName), SubExp)]
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, VName)] -> [SubExp] -> [((VName, VName), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(VName, VName)]
dests (HistOp KernelsMem -> [SubExp]
forall lore. HistOp lore -> [SubExp]
histNeutral (HistOp KernelsMem -> [SubExp]) -> HistOp KernelsMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug)) ((((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ())
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              \((VName
dest_global, VName
dest_local), SubExp
ne) ->
                String -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
String -> Exp -> (Exp -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"local_i" Exp
init_per_thread ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
i -> do
                  Exp
j <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"j" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
                       Exp
i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupSize KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+
                       KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
                  Exp
j_offset <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"j_offset" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
                              Exp
num_subhistos_per_group Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
histo_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
gid_in_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
j

                  Exp
local_subhisto_i <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"local_subhisto_i" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
j Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` Exp
histo_size
                  let local_bucket_is :: [Exp]
local_bucket_is = [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [Exp]
histo_dims (Exp -> [Exp]) -> Exp -> [Exp]
forall a b. (a -> b) -> a -> b
$ Exp
j Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`rem` Exp
histo_size
                      global_bucket_is :: [Exp]
global_bucket_is = [Exp] -> Exp
forall a. [a] -> a
head [Exp]
local_bucket_is Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
chk_i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_H_chk Exp -> [Exp] -> [Exp]
forall a. a -> [a] -> [a]
:
                                         [Exp] -> [Exp]
forall a. [a] -> [a]
tail [Exp]
local_bucket_is
                  Exp
global_subhisto_i <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"global_subhisto_i" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
j_offset Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` Exp
histo_size

                  Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
j Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
group_hists_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                    VName
-> VName
-> HistOp KernelsMem
-> SubExp
-> Exp
-> Exp
-> [Exp]
-> [Exp]
-> InKernelGen ()
f VName
dest_local VName
dest_global (SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug) SubExp
ne
                    Exp
local_subhisto_i Exp
global_subhisto_i
                    [Exp]
local_bucket_is [Exp]
global_bucket_is

    String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"initialize histograms in local memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      (VName
 -> VName
 -> HistOp KernelsMem
 -> SubExp
 -> Exp
 -> Exp
 -> [Exp]
 -> [Exp]
 -> InKernelGen ())
-> InKernelGen ()
onAllHistograms ((VName
  -> VName
  -> HistOp KernelsMem
  -> SubExp
  -> Exp
  -> Exp
  -> [Exp]
  -> [Exp]
  -> InKernelGen ())
 -> InKernelGen ())
-> (VName
    -> VName
    -> HistOp KernelsMem
    -> SubExp
    -> Exp
    -> Exp
    -> [Exp]
    -> [Exp]
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
dest_local VName
dest_global HistOp KernelsMem
op SubExp
ne Exp
local_subhisto_i Exp
global_subhisto_i [Exp]
local_bucket_is [Exp]
global_bucket_is ->
      String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"First subhistogram is initialised from global memory; others with neutral element." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
      let global_is :: [Exp]
global_is = (VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Exp
Imp.vi32 [VName]
segment_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp
0] [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp]
global_bucket_is
          local_is :: [Exp]
local_is = Exp
local_subhisto_i Exp -> [Exp] -> [Exp]
forall a. a -> [a] -> [a]
: [Exp]
local_bucket_is
      Exp -> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall lore r op.
Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (Exp
global_subhisto_i Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0)
        (VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
dest_local [Exp]
local_is (VName -> SubExp
Var VName
dest_global) [Exp]
global_is)
        (Shape -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([Exp] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp KernelsMem
op) (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
is ->
            VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
dest_local ([Exp]
local_is[Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++[Exp]
is) SubExp
ne [])

    KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal

    Exp -> Exp -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
kernelLoop Exp
pgtid_in_segment Exp
threads_per_segment Exp
segment_size' ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
ie -> do
      VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
dPrimV_ VName
i_in_segment Exp
ie

      -- We execute the bucket function once and update each histogram
      -- serially.  This also involves writing to the mapout arrays if
      -- this is the first chunk.

      Names -> Stms KernelsMem -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody KernelsMem -> Stms KernelsMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody KernelsMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do

        let ([SubExp]
red_res, [SubExp]
map_res) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElemT LetAttrMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem KernelsMem]
[PatElemT LetAttrMem]
map_pes) ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$
                             (KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp ([KernelResult] -> [SubExp]) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ KernelBody KernelsMem -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody KernelsMem
kbody
            ([SubExp]
buckets, [SubExp]
vs) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegHistSlug] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SegHistSlug]
slugs) [SubExp]
red_res
            perOp :: [SubExp] -> [[SubExp]]
perOp = [Int] -> [SubExp] -> [[SubExp]]
forall a. [Int] -> [a] -> [[a]]
chunks ([Int] -> [SubExp] -> [[SubExp]])
-> [Int] -> [SubExp] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Int) -> [SegHistSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int) -> (SegHistSlug -> [VName]) -> SegHistSlug -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> [VName]
forall lore. HistOp lore -> [VName]
histDest (HistOp KernelsMem -> [VName])
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs

        Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
chk_i Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"save map-out results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          [(PatElemT LetAttrMem, SubExp)]
-> ((PatElemT LetAttrMem, SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetAttrMem]
-> [SubExp] -> [(PatElemT LetAttrMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem KernelsMem]
[PatElemT LetAttrMem]
map_pes [SubExp]
map_res) (((PatElemT LetAttrMem, SubExp) -> InKernelGen ())
 -> InKernelGen ())
-> ((PatElemT LetAttrMem, SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetAttrMem
pe, SubExp
se) ->
          VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT LetAttrMem -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT LetAttrMem
pe)
          ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Exp
Imp.vi32 [VName]
space_is) SubExp
se []

        [(HistOp KernelsMem,
  ([(VName, VName)], VName, [Exp] -> InKernelGen ()), SubExp,
  [SubExp])]
-> ((HistOp KernelsMem,
     ([(VName, VName)], VName, [Exp] -> InKernelGen ()), SubExp,
     [SubExp])
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp KernelsMem]
-> [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
-> [SubExp]
-> [[SubExp]]
-> [(HistOp KernelsMem,
     ([(VName, VName)], VName, [Exp] -> InKernelGen ()), SubExp,
     [SubExp])]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 ((SegHistSlug -> HistOp KernelsMem)
-> [SegHistSlug] -> [HistOp KernelsMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp KernelsMem
slugOp [SegHistSlug]
slugs) [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
histograms [SubExp]
buckets ([SubExp] -> [[SubExp]]
perOp [SubExp]
vs)) (((HistOp KernelsMem,
   ([(VName, VName)], VName, [Exp] -> InKernelGen ()), SubExp,
   [SubExp])
  -> InKernelGen ())
 -> InKernelGen ())
-> ((HistOp KernelsMem,
     ([(VName, VName)], VName, [Exp] -> InKernelGen ()), SubExp,
     [SubExp])
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          \(HistOp SubExp
dest_w SubExp
_ [VName]
_ [SubExp]
_ Shape
shape LambdaT KernelsMem
lam,
            ([(VName, VName)]
_, VName
hist_H_chk, [Exp] -> InKernelGen ()
do_op), SubExp
bucket, [SubExp]
vs') -> do

            let chk_beg :: Exp
chk_beg = Exp
chk_i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* VName -> Exp
Imp.vi32 VName
hist_H_chk
                bucket' :: Exp
bucket' = PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 SubExp
bucket
                dest_w' :: Exp
dest_w' = PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 SubExp
dest_w
                bucket_in_bounds :: Exp
bucket_in_bounds = Exp
bucket' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
dest_w' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&.
                                   Exp
chk_beg Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<=. Exp
bucket' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&.
                                   Exp
bucket' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. (Exp
chk_beg Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ VName -> Exp
Imp.vi32 VName
hist_H_chk)
                bucket_is :: [Exp]
bucket_is = [Exp
thread_local_subhisto_i, Exp
bucket' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
chk_beg]
                vs_params :: [Param LetAttrMem]
vs_params = Int -> [Param LetAttrMem] -> [Param LetAttrMem]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') ([Param LetAttrMem] -> [Param LetAttrMem])
-> [Param LetAttrMem] -> [Param LetAttrMem]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT KernelsMem
lam

            String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
bucket_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
              [LParam KernelsMem] -> InKernelGen ()
forall lore r op. Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams ([LParam KernelsMem] -> InKernelGen ())
-> [LParam KernelsMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT KernelsMem
lam
              Shape -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([Exp] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest Shape
shape (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
is -> do
                [(Param LetAttrMem, SubExp)]
-> ((Param LetAttrMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetAttrMem] -> [SubExp] -> [(Param LetAttrMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetAttrMem]
vs_params [SubExp]
vs') (((Param LetAttrMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetAttrMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetAttrMem
p, SubExp
v) ->
                  VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param LetAttrMem -> VName
forall attr. Param attr -> VName
paramName Param LetAttrMem
p) [] SubExp
v [Exp]
is
                [Exp] -> InKernelGen ()
do_op ([Exp]
bucket_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp]
is)

    KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceGlobal

    String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Compact the multiple local memory subhistograms to result in global memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      (SegHistSlug
 -> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
  -> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
 -> InKernelGen ())
-> (SegHistSlug
    -> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests Exp
hist_H_chk [Exp]
histo_dims Exp
histo_size -> do
      Exp
bins_per_thread <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"init_per_thread" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
                         Exp
histo_size Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp` KernelConstants -> Exp
kernelGroupSize KernelConstants
constants

      VName
trunc_H <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"trunc_H" (Exp -> ImpM KernelsMem KernelEnv KernelOp VName)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$
                 BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMin IntType
Int32) Exp
hist_H_chk (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
                 PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth (SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug)) Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
-
                 Exp
chk_i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* [Exp] -> Exp
forall a. [a] -> a
head [Exp]
histo_dims
      let trunc_histo_dims :: [Exp]
trunc_histo_dims = (SubExp -> Exp) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) ([SubExp] -> [Exp]) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
trunc_H SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
:
                             Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape (SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug))
      Exp
trunc_histo_size <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"histo_size" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Exp]
trunc_histo_dims

      String -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
String -> Exp -> (Exp -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"local_i" Exp
bins_per_thread ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
i -> do
        Exp
j <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"j" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
             Exp
i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupSize KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
        Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
j Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
trunc_histo_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
          -- We are responsible for compacting the flat bin 'j', which
          -- we immediately unflatten.
          let local_bucket_is :: [Exp]
local_bucket_is = [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [Exp]
histo_dims Exp
j
              global_bucket_is :: [Exp]
global_bucket_is = [Exp] -> Exp
forall a. [a] -> a
head [Exp]
local_bucket_is Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
chk_i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_H_chk Exp -> [Exp] -> [Exp]
forall a. a -> [a] -> [a]
:
                                 [Exp] -> [Exp]
forall a. [a] -> [a]
tail [Exp]
local_bucket_is
          [LParam KernelsMem] -> InKernelGen ()
forall lore r op. Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams ([LParam KernelsMem] -> InKernelGen ())
-> [LParam KernelsMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams (LambdaT KernelsMem -> [LParam KernelsMem])
-> LambdaT KernelsMem -> [LParam KernelsMem]
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp (HistOp KernelsMem -> LambdaT KernelsMem)
-> HistOp KernelsMem -> LambdaT KernelsMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug
          let ([VName]
global_dests, [VName]
local_dests) = [(VName, VName)] -> ([VName], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, VName)]
dests
              ([Param LetAttrMem]
xparams, [Param LetAttrMem]
yparams) = Int
-> [Param LetAttrMem] -> ([Param LetAttrMem], [Param LetAttrMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
local_dests) ([Param LetAttrMem] -> ([Param LetAttrMem], [Param LetAttrMem]))
-> [Param LetAttrMem] -> ([Param LetAttrMem], [Param LetAttrMem])
forall a b. (a -> b) -> a -> b
$
                                   LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams (LambdaT KernelsMem -> [LParam KernelsMem])
-> LambdaT KernelsMem -> [LParam KernelsMem]
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp (HistOp KernelsMem -> LambdaT KernelsMem)
-> HistOp KernelsMem -> LambdaT KernelsMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug

          String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Read values from subhistogram 0." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            [(Param LetAttrMem, VName)]
-> ((Param LetAttrMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetAttrMem] -> [VName] -> [(Param LetAttrMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetAttrMem]
xparams [VName]
local_dests) (((Param LetAttrMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetAttrMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetAttrMem
xp, VName
subhisto) ->
            VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix
            (Param LetAttrMem -> VName
forall attr. Param attr -> VName
paramName Param LetAttrMem
xp) []
            (VName -> SubExp
Var VName
subhisto) (Exp
0Exp -> [Exp] -> [Exp]
forall a. a -> [a] -> [a]
:[Exp]
local_bucket_is)

          String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Accumulate based on values in other subhistograms." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            String -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
String -> Exp -> (Exp -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"subhisto_id" (Exp
num_subhistos_per_group Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1) ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
subhisto_id -> do
              [(Param LetAttrMem, VName)]
-> ((Param LetAttrMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetAttrMem] -> [VName] -> [(Param LetAttrMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetAttrMem]
yparams [VName]
local_dests) (((Param LetAttrMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetAttrMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetAttrMem
yp, VName
subhisto) ->
                VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix
                (Param LetAttrMem -> VName
forall attr. Param attr -> VName
paramName Param LetAttrMem
yp) []
                (VName -> SubExp
Var VName
subhisto) (Exp
subhisto_id Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
1 Exp -> [Exp] -> [Exp]
forall a. a -> [a] -> [a]
: [Exp]
local_bucket_is)
              [Param LetAttrMem] -> Body KernelsMem -> InKernelGen ()
forall attr lore r op.
[Param attr] -> Body lore -> ImpM lore r op ()
compileBody' [Param LetAttrMem]
xparams (Body KernelsMem -> InKernelGen ())
-> Body KernelsMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> Body KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody (LambdaT KernelsMem -> Body KernelsMem)
-> LambdaT KernelsMem -> Body KernelsMem
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp (HistOp KernelsMem -> LambdaT KernelsMem)
-> HistOp KernelsMem -> LambdaT KernelsMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug

          String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Put final bucket value in global memory." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
            let global_is :: [Exp]
global_is =
                  (VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Exp
Imp.vi32 [VName]
segment_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++
                  [Exp
group_id Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`rem` Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment] [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++
                  [Exp]
global_bucket_is
            [(Param LetAttrMem, VName)]
-> ((Param LetAttrMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetAttrMem] -> [VName] -> [(Param LetAttrMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetAttrMem]
xparams [VName]
global_dests) (((Param LetAttrMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetAttrMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetAttrMem
xp, VName
global_dest) ->
              VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
global_dest [Exp]
global_is (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetAttrMem -> VName
forall attr. Param attr -> VName
paramName Param LetAttrMem
xp) []

-- | Generate host code for the local-memory histogram strategy.
--
-- The work is divided into @hist_S@ chunks; for every chunk index a
-- call to 'histKernelLocalPass' is generated inside a host-level
-- @chk_i@ loop.  Before the loop, the intermediate arrays are set up
-- once with 'prepareIntermediateArraysLocal' and the resulting
-- initialisation action is shared by all passes.
--
-- Arguments, in order:
--
-- * the variable holding the number of subhistograms per group (read
--   as an @int32@),
--
-- * the number of groups per segment,
--
-- * the pattern elements for the map-like results (passed through
--   unchanged to 'histKernelLocalPass'),
--
-- * the number of groups and the group size (converted to expressions
--   with 'toExp' before being handed on),
--
-- * the iteration space,
--
-- * @hist_S@, the number of passes/chunks,
--
-- * the histogram slugs and the kernel body.
histKernelLocal :: VName -> Count NumGroups Imp.Exp
                -> [PatElem KernelsMem]
                -> Count NumGroups SubExp -> Count GroupSize SubExp
                -> SegSpace
                -> Imp.Exp
                -> [SegHistSlug]
                -> KernelBody KernelsMem
                -> CallKernelGen ()
histKernelLocal num_subhistos_per_group_var groups_per_segment map_pes num_groups group_size space hist_S slugs kbody = do
  num_groups' <- traverse toExp num_groups
  group_size' <- traverse toExp group_size
  let num_subhistos_per_group = Imp.var num_subhistos_per_group_var int32

  emit $ Imp.DebugPrint "Number of local subhistograms per group" $
    Just num_subhistos_per_group

  -- Set up the intermediate arrays once; every pass below reuses the
  -- same initialisation action.
  init_histograms <-
    prepareIntermediateArraysLocal
    num_subhistos_per_group_var groups_per_segment space slugs

  -- One pass per chunk.
  sFor "chk_i" hist_S $ \chk_i ->
    histKernelLocalPass
    num_subhistos_per_group_var groups_per_segment map_pes
    num_groups' group_size' space slugs kbody
    init_histograms hist_S chk_i

-- | The maximum number of passes we are willing to accept for this
-- kind of atomic update.
--
-- The limit depends only on which 'slugAtomicUpdate' constructor the
-- slug carries: 3 for 'AtomicPrim', 4 for 'AtomicCAS' and 6 for
-- 'AtomicLocking'.  The constants are heuristic thresholds; the
-- payload of the constructor is ignored.
slugMaxLocalMemPasses :: SegHistSlug -> Int
slugMaxLocalMemPasses slug =
  case slugAtomicUpdate slug of
    AtomicPrim _    -> 3
    AtomicCAS _     -> 4
    AtomicLocking _ -> 6

-- | Prepare the local-memory histogram strategy.
--
-- Returns a pair @(pick_local, run)@:
--
-- * @pick_local@ is a boolean expression, evaluated in the generated
--   code, that is true when the local-memory strategy should be used.
--
-- * @run@ is the code-generation action that emits the local-memory
--   implementation (via 'histKernelLocal').  It is only returned here,
--   not executed; the caller decides where to place it.
--
-- As a side effect, this function declares and computes a number of
-- host-level variables (@hist_L@, @max_group_size@, @num_groups@,
-- @hist_M@, @hist_C@, @hist_S@, ...) and emits debug prints for them.
--
-- Arguments: @map_pes@, @space@, @slugs@ and @kbody@ are passed on to
-- 'histKernelLocal'.  @hist_T@ is divided by the group size to obtain
-- the number of groups (presumably the desired total thread count -
-- TODO confirm against the caller).  @hist_H@ is printed as the
-- "Histogram size (H)".  @hist_el_size@ is multiplied with the number
-- of subhistograms to obtain the local memory requirement.  @hist_N@
-- is used together with the number of groups when bounding @M@.  The
-- seventh argument is ignored.
--
-- NOTE(review): 'init' and 'last' on the space dimensions assume a
-- non-empty 'SegSpace', and 'maximum' / the division by
-- @genericLength slugs@ assume at least one slug.  Presumably the
-- caller guarantees both; verify.
localMemoryCase :: [PatElem KernelsMem]
                -> Imp.Exp
                -> SegSpace
                -> Imp.Exp -> Imp.Exp -> Imp.Exp -> Imp.Exp
                -> [SegHistSlug]
                -> KernelBody KernelsMem
                -> CallKernelGen (Imp.Exp, CallKernelGen ())
localMemoryCase map_pes hist_T space hist_H hist_el_size hist_N _ slugs kbody = do
  let space_sizes = segSpaceDims space
      -- All but the innermost dimension are segment dimensions.
      segment_dims = init space_sizes
      segmented = not $ null segment_dims

  -- Maximum amount of local memory, queried at run time.
  hist_L <- dPrim "hist_L" int32
  sOp $ Imp.GetSizeMax hist_L Imp.SizeLocalMemory

  -- The local-memory kernels always run with the maximum group size.
  max_group_size <- dPrim "max_group_size" int32
  sOp $ Imp.GetSizeMax max_group_size Imp.SizeGroup
  let group_size = Imp.Count $ Var max_group_size
  num_groups <- fmap (Imp.Count . Var) $ dPrimV "num_groups" $
                hist_T `quotRoundingUp` toExp' int32 (unCount group_size)
  let num_groups' = toExp' int32 <$> num_groups
      group_size' = toExp' int32 <$> group_size

  -- Conversion helpers between the integer and float types used below.
  let r64 = ConvOpExp (SIToFP Int32 Float64)
      t64 = ConvOpExp (FPToSI Float64 Int32)
      i32_to_i64 = ConvOpExp (SExt Int32 Int64)
      i64_to_i32 = ConvOpExp (SExt Int64 Int32)

  -- M approximation.
  hist_m' <- dPrimVE "hist_m_prime" $
             r64 (Imp.BinOpExp (SMin Int32)
                  (Imp.vi32 hist_L `quot` hist_el_size)
                  (hist_N `quotRoundingUp` unCount num_groups'))
             / r64 hist_H

  let hist_B = unCount group_size'

  -- M in the paper, but not adjusted for asymptotic efficiency.
  -- Clamped to the range [1, hist_B].
  hist_M0 <- dPrimVE "hist_M0" $
             Imp.BinOpExp (SMax Int32) 1 $
             Imp.BinOpExp (SMin Int32) (t64 hist_m') hist_B

  -- Minimal sequential chunking factor.
  let q_small = 2

  -- The number of segments/histograms produced.
  hist_Nout <- dPrimVE "hist_Nout" $ product $ map (toExp' int32) segment_dims

  -- The size of the innermost dimension.
  hist_Nin <- dPrimVE "hist_Nin" $ toExp' int32 $ last space_sizes

  -- Maximum M for work efficiency.
  work_asymp_M_max <-
    if segmented then do

      -- Computed in 64 bits because hist_Nin * hist_Nout may not fit
      -- in 32 bits.
      hist_T_hist_min <- dPrimVE "hist_T_hist_min" $
                         i64_to_i32 $
                         Imp.BinOpExp (SMin Int64)
                         (i32_to_i64 hist_Nin * i32_to_i64 hist_Nout)
                         (i32_to_i64 hist_T)
                         `quotRoundingUp`
                         i32_to_i64 hist_Nout

      -- Number of groups, rounded up.
      let r = hist_T_hist_min `quotRoundingUp` hist_B

      dPrimVE "work_asymp_M_max" $ hist_Nin `quot` (r * hist_H)

    else dPrimVE "work_asymp_M_max" $
         (hist_Nout * hist_N) `quot`
         ((q_small * unCount num_groups' * hist_H)
          `quot` genericLength slugs)

  -- Number of subhistograms per result histogram.
  hist_M <- dPrimV "hist_M" $
            Imp.BinOpExp (SMin Int32) hist_M0 work_asymp_M_max

  -- hist_M may be zero (which we'll check for below), but we need it
  -- for some divisions first, so crudely make a nonzero form.
  let hist_M_nonzero = Imp.BinOpExp (SMax Int32) 1 $ Imp.vi32 hist_M

  -- "Cooperation factor" - the number of threads cooperatively
  -- working on the same (sub)histogram.
  hist_C <- dPrimVE "hist_C" $
            hist_B `quotRoundingUp` hist_M_nonzero

  emit $ Imp.DebugPrint "local hist_M0" $ Just hist_M0
  emit $ Imp.DebugPrint "local work asymp M max" $ Just work_asymp_M_max
  emit $ Imp.DebugPrint "local C" $ Just hist_C
  emit $ Imp.DebugPrint "local B" $ Just hist_B
  emit $ Imp.DebugPrint "local M" $ Just $ Imp.vi32 hist_M
  emit $ Imp.DebugPrint "local memory needed" $
    Just $ hist_H * hist_el_size * Imp.vi32 hist_M

  -- local_mem_needed is what we need to keep a single bucket in local
  -- memory - this is an absolute minimum.  We can fit anything else
  -- by doing multiple passes, although more than a few is
  -- (heuristically) not efficient.
  local_mem_needed <- dPrimVE "local_mem_needed" $
                      hist_el_size * Imp.vi32 hist_M
  hist_S <- dPrimVE "hist_S" $
            (hist_H * local_mem_needed) `quotRoundingUp` Imp.vi32 hist_L

  -- Upper bound on the number of passes we accept: a body that must
  -- be single-pass allows exactly one, otherwise the largest limit
  -- among the slugs applies.
  let max_S = case bodyPassage kbody of
                MustBeSinglePass -> 1
                MayBeMultiPass ->
                  fromIntegral $ maximum $ map slugMaxLocalMemPasses slugs

  -- We only use local memory if the number of updates per histogram
  -- at least matches the histogram size, as otherwise it is not
  -- asymptotically efficient.  This mostly matters for the segmented
  -- case.
  let pick_local =
        hist_Nin .>=. hist_H
        .&&. (local_mem_needed .<=. Imp.vi32 hist_L)
        .&&. (hist_S .<=. max_S)
        .&&. hist_C .<=. hist_B
        .&&. Imp.vi32 hist_M .>. 0

      groups_per_segment
        | segmented = num_groups' `quotRoundingUp` Imp.Count hist_Nout
        | otherwise = num_groups'

      run = do
        emit $ Imp.DebugPrint "## Using local memory" Nothing
        emit $ Imp.DebugPrint "Histogram size (H)" $ Just hist_H
        emit $ Imp.DebugPrint "Multiplication degree (M)" $
          Just $ Imp.vi32 hist_M
        emit $ Imp.DebugPrint "Cooperation level (C)" $ Just hist_C
        emit $ Imp.DebugPrint "Number of chunks (S)" $ Just hist_S
        when segmented $
          emit $ Imp.DebugPrint "Groups per segment" $
          Just $ unCount groups_per_segment
        histKernelLocal hist_M groups_per_segment map_pes
          num_groups group_size space hist_S slugs kbody

  return (pick_local, run)

-- Most of this function is not the histogram part itself, but rather
-- figuring out whether to use a local or global memory strategy, as
-- well as collapsing the subhistograms produced (which are always in
-- global memory, but their number may vary).
compileSegHist :: Pattern KernelsMem
               -> Count NumGroups SubExp -> Count GroupSize SubExp
               -> SegSpace
               -> [HistOp KernelsMem]
               -> KernelBody KernelsMem
               -> CallKernelGen ()
compileSegHist :: Pattern KernelsMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [HistOp KernelsMem]
-> KernelBody KernelsMem
-> CallKernelGen ()
compileSegHist (Pattern [PatElem KernelsMem]
_ [PatElem KernelsMem]
pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [HistOp KernelsMem]
ops KernelBody KernelsMem
kbody = do
  Count NumGroups Exp
num_groups' <- (SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Count NumGroups SubExp
-> ImpM KernelsMem HostEnv HostOp (Count NumGroups Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count NumGroups SubExp
num_groups
  Count GroupSize Exp
group_size' <- (SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Count GroupSize SubExp
-> ImpM KernelsMem HostEnv HostOp (Count GroupSize Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count GroupSize SubExp
group_size

  [Exp]
dims <- (SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
-> [SubExp] -> ImpM KernelsMem HostEnv HostOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp ([SubExp] -> ImpM KernelsMem HostEnv HostOp [Exp])
-> [SubExp] -> ImpM KernelsMem HostEnv HostOp [Exp]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space

  let num_red_res :: Int
num_red_res = [HistOp KernelsMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp KernelsMem]
ops Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((HistOp KernelsMem -> Int) -> [HistOp KernelsMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (HistOp KernelsMem -> [SubExp]) -> HistOp KernelsMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> [SubExp]
forall lore. HistOp lore -> [SubExp]
histNeutral) [HistOp KernelsMem]
ops)
      ([PatElemT LetAttrMem]
all_red_pes, [PatElemT LetAttrMem]
map_pes) = Int
-> [PatElemT LetAttrMem]
-> ([PatElemT LetAttrMem], [PatElemT LetAttrMem])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
num_red_res [PatElem KernelsMem]
[PatElemT LetAttrMem]
pes
      segment_size :: Exp
segment_size = [Exp] -> Exp
forall a. [a] -> a
last [Exp]
dims

  ([Count Bytes Exp]
op_hs, [Count Bytes Exp]
op_seg_hs, [SegHistSlug]
slugs) <- [(Count Bytes Exp, Count Bytes Exp, SegHistSlug)]
-> ([Count Bytes Exp], [Count Bytes Exp], [SegHistSlug])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(Count Bytes Exp, Count Bytes Exp, SegHistSlug)]
 -> ([Count Bytes Exp], [Count Bytes Exp], [SegHistSlug]))
-> ImpM
     KernelsMem
     HostEnv
     HostOp
     [(Count Bytes Exp, Count Bytes Exp, SegHistSlug)]
-> ImpM
     KernelsMem
     HostEnv
     HostOp
     ([Count Bytes Exp], [Count Bytes Exp], [SegHistSlug])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (HistOp KernelsMem
 -> CallKernelGen (Count Bytes Exp, Count Bytes Exp, SegHistSlug))
-> [HistOp KernelsMem]
-> ImpM
     KernelsMem
     HostEnv
     HostOp
     [(Count Bytes Exp, Count Bytes Exp, SegHistSlug)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SegSpace
-> HistOp KernelsMem
-> CallKernelGen (Count Bytes Exp, Count Bytes Exp, SegHistSlug)
computeHistoUsage SegSpace
space) [HistOp KernelsMem]
ops
  Exp
h <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"h" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ Count Bytes Exp -> Exp
forall u e. Count u e -> e
Imp.unCount (Count Bytes Exp -> Exp) -> Count Bytes Exp -> Exp
forall a b. (a -> b) -> a -> b
$ [Count Bytes Exp] -> Count Bytes Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Count Bytes Exp]
op_hs
  Exp
seg_h <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"seg_h" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ Count Bytes Exp -> Exp
forall u e. Count u e -> e
Imp.unCount (Count Bytes Exp -> Exp) -> Count Bytes Exp -> Exp
forall a b. (a -> b) -> a -> b
$ [Count Bytes Exp] -> Count Bytes Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Count Bytes Exp]
op_seg_hs

  -- Check for emptiness to avoid division-by-zero.
  Exp -> CallKernelGen () -> CallKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sUnless (Exp
seg_h Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0) (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do

    -- Maximum group size (or actual, in this case).
    let hist_B :: Exp
hist_B = Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size'

    -- Size of a histogram.
    Exp
hist_H <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_H" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ (HistOp KernelsMem -> Exp) -> [HistOp KernelsMem] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp)
-> (HistOp KernelsMem -> SubExp) -> HistOp KernelsMem -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth) [HistOp KernelsMem]
ops

    -- Size of a single histogram element.  Actually the weighted
    -- average of histogram elements in cases where we have more than
    -- one histogram operation, plus any locks.
    let lockSize :: SegHistSlug -> Maybe a
lockSize SegHistSlug
slug = case SegHistSlug -> AtomicUpdate KernelsMem KernelEnv
slugAtomicUpdate SegHistSlug
slug of
                          AtomicLocking{} -> a -> Maybe a
forall a. a -> Maybe a
Just (a -> Maybe a) -> a -> Maybe a
forall a b. (a -> b) -> a -> b
$ PrimType -> a
forall a. Num a => PrimType -> a
primByteSize PrimType
int32
                          AtomicUpdate KernelsMem KernelEnv
_               -> Maybe a
forall a. Maybe a
Nothing
    Exp
hist_el_size <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_el_size" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ (Exp -> Exp -> Exp) -> Exp -> [Exp] -> Exp
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
(+) (Exp
h Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp` Exp
hist_H) ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$
                    (SegHistSlug -> Maybe Exp) -> [SegHistSlug] -> [Exp]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe SegHistSlug -> Maybe Exp
forall a. Num a => SegHistSlug -> Maybe a
lockSize [SegHistSlug]
slugs

    -- Input elements contributing to each histogram.
    Exp
hist_N <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_N" Exp
segment_size

    -- Compute RF as the average RF over all the histograms.
    Exp
hist_RF <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_RF" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
               [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((SegHistSlug -> Exp) -> [SegHistSlug] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32(SubExp -> Exp) -> (SegHistSlug -> SubExp) -> SegHistSlug -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histRaceFactor (HistOp KernelsMem -> SubExp)
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs)
               Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot`
               [SegHistSlug] -> Exp
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs

    let hist_T :: Exp
hist_T = Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
num_groups' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size'
    Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"\n# SegHist" Maybe Exp
forall a. Maybe a
Nothing
    Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of threads (T)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_T
    Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Desired group size (B)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_B
    Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Histogram size (H)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_H
    Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Input elements per histogram (N)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_N
    Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
      Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ ((VName, SubExp) -> Exp) -> [(VName, SubExp)] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp)
-> ((VName, SubExp) -> SubExp) -> (VName, SubExp) -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) [(VName, SubExp)]
segment_dims
    Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Histogram element size (el_size)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_el_size
    Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Race factor (RF)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_RF
    Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Memory per set of subhistograms per segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
h
    Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Memory per set of subhistograms times segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
seg_h

    (Exp
use_local_memory, CallKernelGen ()
run_in_local_memory) <-
      [PatElem KernelsMem]
-> Exp
-> SegSpace
-> Exp
-> Exp
-> Exp
-> Exp
-> [SegHistSlug]
-> KernelBody KernelsMem
-> CallKernelGen (Exp, CallKernelGen ())
localMemoryCase [PatElem KernelsMem]
[PatElemT LetAttrMem]
map_pes Exp
hist_T SegSpace
space Exp
hist_H Exp
hist_el_size Exp
hist_N Exp
hist_RF [SegHistSlug]
slugs KernelBody KernelsMem
kbody

    Exp -> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall lore r op.
Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf Exp
use_local_memory CallKernelGen ()
run_in_local_memory (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
      [PatElem KernelsMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegHistSlug]
-> KernelBody KernelsMem
-> CallKernelGen ()
histKernelGlobal [PatElem KernelsMem]
[PatElemT LetAttrMem]
map_pes Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegHistSlug]
slugs KernelBody KernelsMem
kbody

    let pes_per_op :: [[PatElemT LetAttrMem]]
pes_per_op = [Int] -> [PatElemT LetAttrMem] -> [[PatElemT LetAttrMem]]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp KernelsMem -> Int) -> [HistOp KernelsMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int)
-> (HistOp KernelsMem -> [VName]) -> HistOp KernelsMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> [VName]
forall lore. HistOp lore -> [VName]
histDest) [HistOp KernelsMem]
ops) [PatElemT LetAttrMem]
all_red_pes

    [(SegHistSlug, [PatElemT LetAttrMem], HistOp KernelsMem)]
-> ((SegHistSlug, [PatElemT LetAttrMem], HistOp KernelsMem)
    -> CallKernelGen ())
-> CallKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegHistSlug]
-> [[PatElemT LetAttrMem]]
-> [HistOp KernelsMem]
-> [(SegHistSlug, [PatElemT LetAttrMem], HistOp KernelsMem)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegHistSlug]
slugs [[PatElemT LetAttrMem]]
pes_per_op [HistOp KernelsMem]
ops) (((SegHistSlug, [PatElemT LetAttrMem], HistOp KernelsMem)
  -> CallKernelGen ())
 -> CallKernelGen ())
-> ((SegHistSlug, [PatElemT LetAttrMem], HistOp KernelsMem)
    -> CallKernelGen ())
-> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegHistSlug
slug, [PatElemT LetAttrMem]
red_pes, HistOp KernelsMem
op) -> do
      let num_histos :: VName
num_histos = SegHistSlug -> VName
slugNumSubhistos SegHistSlug
slug
          subhistos :: [VName]
subhistos = (SubhistosInfo -> VName) -> [SubhistosInfo] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map SubhistosInfo -> VName
subhistosArray ([SubhistosInfo] -> [VName]) -> [SubhistosInfo] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> [SubhistosInfo]
slugSubhistos SegHistSlug
slug

      let unitHistoCase :: CallKernelGen ()
unitHistoCase =
            -- This is OK because the memory blocks are at least as
            -- large as the ones we are supposed to use for the result.
            [(PatElemT LetAttrMem, VName)]
-> ((PatElemT LetAttrMem, VName) -> CallKernelGen ())
-> CallKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetAttrMem] -> [VName] -> [(PatElemT LetAttrMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LetAttrMem]
red_pes [VName]
subhistos) (((PatElemT LetAttrMem, VName) -> CallKernelGen ())
 -> CallKernelGen ())
-> ((PatElemT LetAttrMem, VName) -> CallKernelGen ())
-> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetAttrMem
pe, VName
subhisto) -> do
              VName
pe_mem <- MemLocation -> VName
memLocationName (MemLocation -> VName)
-> (ArrayEntry -> MemLocation) -> ArrayEntry -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ArrayEntry -> MemLocation
entryArrayLocation (ArrayEntry -> VName)
-> ImpM KernelsMem HostEnv HostOp ArrayEntry
-> ImpM KernelsMem HostEnv HostOp VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$>
                        VName -> ImpM KernelsMem HostEnv HostOp ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray (PatElemT LetAttrMem -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT LetAttrMem
pe)
              VName
subhisto_mem <- MemLocation -> VName
memLocationName (MemLocation -> VName)
-> (ArrayEntry -> MemLocation) -> ArrayEntry -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ArrayEntry -> MemLocation
entryArrayLocation (ArrayEntry -> VName)
-> ImpM KernelsMem HostEnv HostOp ArrayEntry
-> ImpM KernelsMem HostEnv HostOp VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$>
                              VName -> ImpM KernelsMem HostEnv HostOp ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray VName
subhisto
              Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> Space -> Code HostOp
forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
pe_mem VName
subhisto_mem (Space -> Code HostOp) -> Space -> Code HostOp
forall a b. (a -> b) -> a -> b
$ String -> Space
Space String
"device"

      Exp -> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall lore r op.
Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (VName -> PrimType -> Exp
Imp.var VName
num_histos PrimType
int32 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
1) CallKernelGen ()
unitHistoCase (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
        -- For the segmented reduction, we keep the segment dimensions
        -- unchanged.  To this, we add two dimensions: one over the number
        -- of buckets, and one over the number of subhistograms.  This
        -- inner dimension is the one that is collapsed in the reduction.
        let num_buckets :: SubExp
num_buckets = HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth HistOp KernelsMem
op

        VName
bucket_id <- String -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"bucket_id"
        VName
subhistogram_id <- String -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"subhistogram_id"
        [VName]
vector_ids <- (SubExp -> ImpM KernelsMem HostEnv HostOp VName)
-> [SubExp] -> ImpM KernelsMem HostEnv HostOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (ImpM KernelsMem HostEnv HostOp VName
-> SubExp -> ImpM KernelsMem HostEnv HostOp VName
forall a b. a -> b -> a
const (ImpM KernelsMem HostEnv HostOp VName
 -> SubExp -> ImpM KernelsMem HostEnv HostOp VName)
-> ImpM KernelsMem HostEnv HostOp VName
-> SubExp
-> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ String -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"vector_id") ([SubExp] -> ImpM KernelsMem HostEnv HostOp [VName])
-> [SubExp] -> ImpM KernelsMem HostEnv HostOp [VName]
forall a b. (a -> b) -> a -> b
$
                      Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp KernelsMem
op

        VName
flat_gtid <- String -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"flat_gtid"

        let lvl :: SegLevel
lvl = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegVirt
SegVirt
            segred_space :: SegSpace
segred_space =
              VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
flat_gtid ([(VName, SubExp)] -> SegSpace) -> [(VName, SubExp)] -> SegSpace
forall a b. (a -> b) -> a -> b
$
              [(VName, SubExp)]
segment_dims [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++
              [(VName
bucket_id, SubExp
num_buckets)] [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++
              [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
vector_ids (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp KernelsMem
op) [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++
              [(VName
subhistogram_id, VName -> SubExp
Var VName
num_histos)]

        let segred_op :: SegBinOp KernelsMem
segred_op = Commutativity
-> LambdaT KernelsMem -> [SubExp] -> Shape -> SegBinOp KernelsMem
forall lore.
Commutativity -> Lambda lore -> [SubExp] -> Shape -> SegBinOp lore
SegBinOp Commutativity
Commutative (HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp HistOp KernelsMem
op) (HistOp KernelsMem -> [SubExp]
forall lore. HistOp lore -> [SubExp]
histNeutral HistOp KernelsMem
op) Shape
forall a. Monoid a => a
mempty
        Pattern KernelsMem
-> SegLevel
-> SegSpace
-> [SegBinOp KernelsMem]
-> DoSegBody
-> CallKernelGen ()
compileSegRed' ([PatElemT LetAttrMem]
-> [PatElemT LetAttrMem] -> PatternT LetAttrMem
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [] [PatElemT LetAttrMem]
red_pes) SegLevel
lvl SegSpace
segred_space [SegBinOp KernelsMem
segred_op] (DoSegBody -> CallKernelGen ()) -> DoSegBody -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [Exp])] -> InKernelGen ()
red_cont ->
          [(SubExp, [Exp])] -> InKernelGen ()
red_cont ([(SubExp, [Exp])] -> InKernelGen ())
-> [(SubExp, [Exp])] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ ((VName -> (SubExp, [Exp])) -> [VName] -> [(SubExp, [Exp])])
-> [VName] -> (VName -> (SubExp, [Exp])) -> [(SubExp, [Exp])]
forall a b c. (a -> b -> c) -> b -> a -> c
flip (VName -> (SubExp, [Exp])) -> [VName] -> [(SubExp, [Exp])]
forall a b. (a -> b) -> [a] -> [b]
map [VName]
subhistos ((VName -> (SubExp, [Exp])) -> [(SubExp, [Exp])])
-> (VName -> (SubExp, [Exp])) -> [(SubExp, [Exp])]
forall a b. (a -> b) -> a -> b
$ \VName
subhisto ->
            (VName -> SubExp
Var VName
subhisto, (VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Exp
Imp.vi32 ([VName] -> [Exp]) -> [VName] -> [Exp]
forall a b. (a -> b) -> a -> b
$
              ((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst [(VName, SubExp)]
segment_dims [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName
subhistogram_id, VName
bucket_id] [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
vector_ids)

  where segment_dims :: [(VName, SubExp)]
segment_dims = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
init ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space