{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
-- | Our compilation strategy for 'SegHist' is based around avoiding
-- bin conflicts.  We do this by splitting the input into chunks, and
-- for each chunk computing a single subhistogram.  Then we combine
-- the subhistograms using an ordinary segmented reduction ('SegRed').
--
-- There are some branches around to efficiently handle the case where
-- we use only a single subhistogram (because it's large), so that we
-- respect the asymptotics, and do not copy the destination array.
--
-- We also use a heuristic strategy for computing subhistograms in
-- local memory when possible.  Given:
--
-- H: total size of histograms in bytes, including any lock arrays.
--
-- G: group size
--
-- T: number of bytes of local memory each thread can be given without
-- impacting occupancy (determined experimentally, e.g. 32).
--
-- LMAX: maximum amount of local memory per workgroup (hard limit).
--
-- We wish to compute:
--
-- COOP: cooperation level (number of threads per subhistogram)
--
-- LH: number of local memory subhistograms
--
-- We do this as:
--
-- COOP = ceil(H / T)
-- LH = ceil((G*T)/H)
-- if COOP <= G && H <= LMAX then
--   use local memory
-- else
--   use global memory

module Futhark.CodeGen.ImpGen.Kernels.SegHist
  ( compileSegHist )
  where

import Control.Monad.Except
import Data.Maybe
import Data.List (foldl', genericLength, zip4, zip6)

import Prelude hiding (quot, rem)

import Futhark.MonadFreshNames
import Futhark.Representation.ExplicitMemory
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Pass.ExplicitAllocations()
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.SegRed (compileSegRed')
import Futhark.CodeGen.ImpGen.Kernels.Base
import Futhark.Util.IntegralExp (quotRoundingUp, quot, rem)
import Futhark.Util (chunks, mapAccumLM, splitFromEnd, takeLast)
import Futhark.Construct (fullSliceNum)

-- | Sign-extend a 32-bit integer expression to 64 bits.
i32Toi64 :: PrimExp v -> PrimExp v
i32Toi64 = ConvOpExp (SExt Int32 Int64)

-- | Per-destination bookkeeping for the subhistograms of a single
-- histogram result array, as produced by 'computeHistoUsage'.
data SubhistosInfo = SubhistosInfo
  { subhistosArray :: VName
    -- ^ Name of the array that holds all subhistograms.
  , subhistosAlloc :: CallKernelGen ()
    -- ^ Deferred host-side action that gives 'subhistosArray' its
    -- memory.  It is only constructed by 'computeHistoUsage'; running
    -- it is the responsibility of the caller.
  }

-- | Everything the code generator keeps around for one histogram
-- operator of a 'SegHist'.
data SegHistSlug = SegHistSlug
  { slugOp :: HistOp ExplicitMemory
    -- ^ The histogram operator itself.
  , slugNumSubhistos :: VName
    -- ^ Variable holding the number of subhistograms.
  , slugSubhistos :: [SubhistosInfo]
    -- ^ One entry per destination array of the operator.
  , slugAtomicUpdate :: AtomicUpdate ExplicitMemory KernelEnv
    -- ^ How bins are updated atomically for this operator.
  }

-- | The number of bytes taken up by a single histogram for the given
-- operator.  Every return type of the operator lambda is first
-- extended with the operator shape ('arrayOfShape' with 'histShape')
-- and then with the histogram width ('arrayOfRow' with 'histWidth');
-- the sizes of the resulting types are summed.
histoSpaceUsage :: HistOp ExplicitMemory
                -> Imp.Count Imp.Bytes Imp.Exp
histoSpaceUsage op =
  sum $
  map (typeSize .
       (`arrayOfRow` histWidth op) .
       (`arrayOfShape` histShape op)) $
  lambdaReturnType $ histOp op

-- | Figure out how much memory is needed per histogram, both
-- segmented and unsegmented, and compute some other auxiliary
-- information.
--
-- The result is a triple of:
--
-- * the size in bytes of a single (unsegmented) histogram, as given
--   by 'histoSpaceUsage';
--
-- * that size multiplied by all but the last dimension of the
--   'SegSpace' (the segment dimensions);
--
-- * a 'SegHistSlug' carrying a freshly declared @num_subhistos@
--   variable, one 'SubhistosInfo' per destination array, and the
--   atomic update strategy chosen by 'atomicUpdateLocking'.
--
-- The subhistogram memory blocks are only declared here; giving them
-- actual memory is deferred to the 'subhistosAlloc' action of each
-- 'SubhistosInfo'.
computeHistoUsage :: SegSpace
                  -> HistOp ExplicitMemory
                  -> CallKernelGen (Imp.Count Imp.Bytes Imp.Exp,
                                    Imp.Count Imp.Bytes Imp.Exp,
                                    SegHistSlug)
computeHistoUsage space op = do
  -- Every dimension of the space except the last one is a segment
  -- dimension.
  let segment_dims = init $ unSegSpace space
      num_segments = length segment_dims

  -- Create names for the intermediate array memory blocks,
  -- memory block sizes, arrays, and number of subhistograms.
  num_subhistos <- dPrim "num_subhistos" int32
  subhisto_infos <- forM (zip (histDest op) (histNeutral op)) $ \(dest, ne) -> do
    dest_t <- lookupType dest
    dest_mem <- entryArrayLocation <$> lookupArray dest

    subhistos_mem <-
      sDeclareMem (baseString dest ++ "_subhistos_mem") (Space "device")

    -- The subhistogram array has the segment dimensions outermost,
    -- then a dimension of size num_subhistos, then whatever remains
    -- of the destination shape after its segment dimensions.
    let subhistos_shape = Shape (map snd segment_dims ++ [Var num_subhistos]) <>
                          stripDims num_segments (arrayShape dest_t)
        subhistos_membind = ArrayIn subhistos_mem $ IxFun.iota $
                            map (primExpFromSubExp int32) $ shapeDims subhistos_shape
    subhistos <- sArray (baseString dest ++ "_subhistos")
                 (elemType dest_t) subhistos_shape subhistos_membind

    return $ SubhistosInfo subhistos $ do
      -- With a single subhistogram we point the subhistogram memory
      -- at the memory of the destination itself instead of
      -- allocating and copying.
      let unitHistoCase =
            emit $
            Imp.SetMem subhistos_mem (memLocationName dest_mem) $
            Space "device"

          -- Otherwise allocate room for all subhistograms, fill them
          -- with the neutral element, and write the destination into
          -- subhistogram 0 of every segment.
          multiHistoCase = do
            let num_elems = foldl' (*) (Imp.var num_subhistos int32) $
                            map (toExp' int32) $ arrayDims dest_t

            let subhistos_mem_size =
                  Imp.bytes $
                  Imp.unCount (Imp.elements num_elems `Imp.withElemType` elemType dest_t)

            sAlloc_ subhistos_mem subhistos_mem_size $ Space "device"
            sReplicate subhistos ne
            subhistos_t <- lookupType subhistos
            let slice = fullSliceNum (map (toExp' int32) $ arrayDims subhistos_t) $
                        map (unitSlice 0 . toExp' int32 . snd) segment_dims ++
                        [DimFix 0]
            sUpdate subhistos slice $ Var dest

      sIf (Imp.var num_subhistos int32 .==. 1) unitHistoCase multiHistoCase

  let h = histoSpaceUsage op
      segmented_h = h * product (map (Imp.bytes . toExp' int32) $ init $ segSpaceDims space)

  atomics <- hostAtomics <$> askEnv

  return (h,
          segmented_h,
          SegHistSlug op num_subhistos subhisto_infos $
          atomicUpdateLocking atomics $ histOp op)

-- | Construct the function that performs the atomic update of the
-- given destination arrays (in the @"global"@ space) for one
-- histogram operator.
--
-- The first argument is a lock array that may already have been set
-- up.  The first component of the result is the locking state to
-- pass on: unchanged when the operator needs no locks or when locks
-- already existed, and a freshly created 'Locking' when the operator
-- needs locks and none were available.
prepareAtomicUpdateGlobal :: Maybe Locking -> [VName] -> SegHistSlug
                          -> CallKernelGen (Maybe Locking,
                                            [Imp.Exp] -> InKernelGen ())
prepareAtomicUpdateGlobal l dests slug =
  -- We need a separate lock array if the operators are not all of a
  -- particularly simple form that permits pure atomic operations.
  case (l, slugAtomicUpdate slug) of
    (_, AtomicPrim f) -> return (l, f (Space "global") dests)
    (_, AtomicCAS f) -> return (l, f (Space "global") dests)
    (Just l', AtomicLocking f) -> return (l, f l' (Space "global") dests)
    (Nothing, AtomicLocking f) -> do
      -- The number of locks used here is too low, but since we are
      -- currently forced to inline a huge list, I'm keeping it down
      -- for now.  Some quick experiments suggested that it has little
      -- impact anyway (maybe the locking case is just too slow).
      --
      -- A fun solution would also be to use a simple hashing
      -- algorithm to ensure good distribution of locks.
      let num_locks = 100151
          dims = map (toExp' int32) $
                 shapeDims (histShape (slugOp slug)) ++
                 [ Var (slugNumSubhistos slug)
                 , histWidth (slugOp slug)]
      locks <-
        sStaticArray "hist_locks" (Space "device") int32 $
        Imp.ArrayZeros num_locks
      -- A bin's lock is found by flattening its index with respect to
      -- 'dims' and wrapping around the size of the lock array.
      -- NOTE(review): the literals 0 1 0 are presumably the
      -- unlocked/locking/unlocking values -- confirm against the
      -- definition of 'Locking'.
      let l' = Locking locks 0 1 0 (pure . (`rem` fromIntegral num_locks) . flattenIndex dims)
      return (Just l', f l' (Space "global") dests)

-- | Some kernel bodies are not safe (or efficient) to execute
-- multiple times.  See 'bodyPassage' for how a body is classified.
data Passage = MustBeSinglePass | MayBeMultiPass deriving (Eq, Ord)

-- | Classify a kernel body: it is 'MayBeMultiPass' exactly when alias
-- analysis shows that the body consumes nothing, and
-- 'MustBeSinglePass' otherwise.
bodyPassage :: KernelBody ExplicitMemory -> Passage
bodyPassage kbody
  | mempty == consumedInKernelBody (aliasAnalyseKernelBody kbody) =
      MayBeMultiPass
  | otherwise =
      MustBeSinglePass

prepareIntermediateArraysGlobal :: Passage -> Imp.Exp -> Imp.Exp -> [SegHistSlug]
                                -> CallKernelGen
                                   (Imp.Exp,
                                    [[Imp.Exp] -> InKernelGen ()])
prepareIntermediateArraysGlobal :: Passage
-> Exp
-> Exp
-> [SegHistSlug]
-> CallKernelGen (Exp, [[Exp] -> InKernelGen ()])
prepareIntermediateArraysGlobal Passage
passage Exp
hist_T Exp
hist_N [SegHistSlug]
slugs = do
  -- The paper formulae assume there is only one histogram, but in our
  -- implementation there can be multiple that have been horisontally
  -- fused.  We do a bit of trickery with summings and averages to
  -- pretend there is really only one.  For the case of a single
  -- histogram, the actual calculations should be the same as in the
  -- paper.

  -- The sum of all Hs.
  Exp
hist_H <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_H" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> ([Exp] -> Exp)
-> [Exp]
-> ImpM ExplicitMemory HostEnv HostOp Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Exp] -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> ImpM ExplicitMemory HostEnv HostOp [Exp]
-> ImpM ExplicitMemory HostEnv HostOp Exp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (SegHistSlug -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> [SegHistSlug] -> ImpM ExplicitMemory HostEnv HostOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> (SegHistSlug -> SubExp)
-> SegHistSlug
-> ImpM ExplicitMemory HostEnv HostOp Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp ExplicitMemory -> SubExp
forall lore. HistOp lore -> SubExp
histWidth (HistOp ExplicitMemory -> SubExp)
-> (SegHistSlug -> HistOp ExplicitMemory) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp ExplicitMemory
slugOp) [SegHistSlug]
slugs

  Exp
hist_RF <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_RF" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
    [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((SegHistSlug -> Exp) -> [SegHistSlug] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 (Exp -> Exp) -> (SegHistSlug -> Exp) -> SegHistSlug -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp) -> (SegHistSlug -> SubExp) -> SegHistSlug -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp ExplicitMemory -> SubExp
forall lore. HistOp lore -> SubExp
histRaceFactor (HistOp ExplicitMemory -> SubExp)
-> (SegHistSlug -> HistOp ExplicitMemory) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp ExplicitMemory
slugOp) [SegHistSlug]
slugs)
    Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/ Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 ([SegHistSlug] -> Exp
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs)

  Exp
hist_el_size <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_el_size" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Exp) -> [SegHistSlug] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> Exp
slugElAvgSize [SegHistSlug]
slugs

  Exp
hist_C_max <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_C_max" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
    BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (FloatType -> BinOp
FMin FloatType
Float64) (Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 Exp
hist_T) (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 Exp
hist_H Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/ Exp
hist_k_ct_min

  Exp
hist_M_min <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_M_min" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
    BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMax IntType
Int32) Exp
1 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ Exp -> Exp
forall v. PrimExp v -> PrimExp v
t64 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 Exp
hist_T Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/ Exp
hist_C_max

  -- Querying L2 cache size is not reliable.  Instead we provide a
  -- tunable knob with a hopefully sane default.
  let hist_L2_def :: Int32
hist_L2_def = Int32
4 Int32 -> Int32 -> Int32
forall a. Num a => a -> a -> a
* Int32
1024 Int32 -> Int32 -> Int32
forall a. Num a => a -> a -> a
* Int32
1024
  VName
hist_L2 <- String -> PrimType -> ImpM ExplicitMemory HostEnv HostOp VName
forall lore r op. String -> PrimType -> ImpM lore r op VName
dPrim String
"L2_size" PrimType
int32
  Maybe Name
entry <- ImpM ExplicitMemory HostEnv HostOp (Maybe Name)
forall lore r op. ImpM lore r op (Maybe Name)
askFunction
  -- Equivalent to F_L2*L2 in paper.
  HostOp -> CallKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> Name -> SizeClass -> HostOp
Imp.GetSize VName
hist_L2
    (Maybe Name -> Name -> Name
keyWithEntryPoint Maybe Name
entry (Name -> Name) -> Name -> Name
forall a b. (a -> b) -> a -> b
$ String -> Name
nameFromString (VName -> String
forall a. Pretty a => a -> String
pretty VName
hist_L2)) (SizeClass -> HostOp) -> SizeClass -> HostOp
forall a b. (a -> b) -> a -> b
$
    Name -> Int32 -> SizeClass
Imp.SizeBespoke (String -> Name
nameFromString String
"L2_for_histogram") Int32
hist_L2_def

  let hist_L2_ln_sz :: Exp
hist_L2_ln_sz = Exp
16Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
*Exp
4 -- L2 cache line size approximation

  Exp
hist_RACE_exp <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_RACE_exp" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
    BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (FloatType -> BinOp
FMax FloatType
Float64) Exp
1 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
    (Exp
hist_k_RF Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_RF) Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/
    (Exp
hist_L2_ln_sz Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/ Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 Exp
hist_el_size)

  VName
hist_S <- String -> PrimType -> ImpM ExplicitMemory HostEnv HostOp VName
forall lore r op. String -> PrimType -> ImpM lore r op VName
dPrim String
"hist_S" PrimType
int32

  -- For sparse histograms (H exceeds N) we only want a single chunk.
  Exp -> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall lore r op.
Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (Exp
hist_N Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
hist_H)
    (VName
hist_S VName -> Exp -> CallKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<-- Exp
1) (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
    VName
hist_S VName -> Exp -> CallKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<--
    case Passage
passage of
      Passage
MayBeMultiPass ->
        (Exp
hist_M_min Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_H Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_el_size) Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp`
        Exp -> Exp
forall v. PrimExp v -> PrimExp v
t64 (Exp
hist_F_L2 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 (VName -> Exp
Imp.vi32 VName
hist_L2) Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_RACE_exp)
      Passage
MustBeSinglePass ->
        Exp
1

  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Race expansion factor (RACE^exp)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_RACE_exp
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of chunks (S)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ VName -> Exp
Imp.vi32 VName
hist_S

  [[Exp] -> InKernelGen ()]
histograms <- (Maybe Locking, [[Exp] -> InKernelGen ()])
-> [[Exp] -> InKernelGen ()]
forall a b. (a, b) -> b
snd ((Maybe Locking, [[Exp] -> InKernelGen ()])
 -> [[Exp] -> InKernelGen ()])
-> ImpM
     ExplicitMemory
     HostEnv
     HostOp
     (Maybe Locking, [[Exp] -> InKernelGen ()])
-> ImpM ExplicitMemory HostEnv HostOp [[Exp] -> InKernelGen ()]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Maybe Locking
 -> SegHistSlug
 -> CallKernelGen (Maybe Locking, [Exp] -> InKernelGen ()))
-> Maybe Locking
-> [SegHistSlug]
-> ImpM
     ExplicitMemory
     HostEnv
     HostOp
     (Maybe Locking, [[Exp] -> InKernelGen ()])
forall (m :: * -> *) acc x y.
Monad m =>
(acc -> x -> m (acc, y)) -> acc -> [x] -> m (acc, [y])
mapAccumLM (Exp
-> Exp
-> Exp
-> Exp
-> Maybe Locking
-> SegHistSlug
-> CallKernelGen (Maybe Locking, [Exp] -> InKernelGen ())
onOp (VName -> Exp
Imp.vi32 VName
hist_L2) Exp
hist_M_min (VName -> Exp
Imp.vi32 VName
hist_S) Exp
hist_RACE_exp) Maybe Locking
forall a. Maybe a
Nothing [SegHistSlug]
slugs

  (Exp, [[Exp] -> InKernelGen ()])
-> CallKernelGen (Exp, [[Exp] -> InKernelGen ()])
forall (m :: * -> *) a. Monad m => a -> m a
return (VName -> Exp
Imp.vi32 VName
hist_S, [[Exp] -> InKernelGen ()]
histograms)
  where
    -- Tuning constants for the global-memory histogram heuristics.  All
    -- three values were chosen experimentally.

    -- NOTE(review): not referenced in the part of the function visible
    -- here; presumably used earlier in the enclosing definition.
    hist_k_ct_min :: Imp.Exp
    hist_k_ct_min = 2

    -- Scaling factor applied to hist_RF when computing the race
    -- expansion factor.
    hist_k_RF :: Imp.Exp
    hist_k_RF = 0.75

    -- Factor multiplied onto the L2 size (hist_L2) wherever the
    -- heuristics budget cache space.
    hist_F_L2 :: Imp.Exp
    hist_F_L2 = 0.4

    -- Conversion from a signed 32-bit integer expression to a 64-bit
    -- floating-point expression.
    r64 = ConvOpExp $ SIToFP Int32 Float64

    -- Conversion from a 64-bit floating-point expression back to a
    -- signed 32-bit integer expression.
    t64 = ConvOpExp $ FPToSI Float64 Int32

    -- "Average element size": the total element size ('slugElSize')
    -- divided by the number of arrays that make up one element.  That
    -- number is the count of operator results, plus one when the update
    -- is done by locking.
    slugElAvgSize slug@(SegHistSlug op _ _ do_op) =
      slugElSize slug `quot` num_arrays
      where num_results = genericLength $ lambdaReturnType $ histOp op
            num_arrays = case do_op of
                           AtomicLocking{} -> 1 + num_results
                           _ -> num_results

    -- Total size in bytes of one histogram element: the sum of the
    -- sizes of the operator's result types, each extended with the
    -- operator's 'histShape'.  When the update is done by locking, one
    -- additional int32 (presumably the lock word) is counted as well.
    slugElSize (SegHistSlug op _ _ do_op) =
      unCount $ sum $ map elemBytes el_types
      where elemBytes t = typeSize $ t `arrayOfShape` histShape op
            res_types = lambdaReturnType $ histOp op
            el_types = case do_op of
                         AtomicLocking{} -> Prim int32 : res_types
                         _ -> res_types

    -- Host-side setup for a single histogram operation.  Intended to be
    -- threaded over all slugs with 'mapAccumLM', with the 'Maybe
    -- Locking' as accumulator.
    --
    -- Arguments: the L2 size (hist_L2), the minimum multiplication
    -- degree (hist_M_min), the number of chunks (hist_S), the race
    -- expansion factor (hist_RACE_exp), the locking state so far, and
    -- the slug itself.  hist_N and hist_T come from the enclosing scope.
    --
    -- Declares hist_H_chk, hist_k_max, hist_u, hist_C and hist_M in the
    -- generated code, stores hist_M in the slug's subhistogram counter,
    -- sets up the subhistogram memory, and finally returns whatever
    -- 'prepareAtomicUpdateGlobal' produces for the destinations.
    onOp hist_L2 hist_M_min hist_S hist_RACE_exp l slug = do
      let SegHistSlug op num_subhistos subhisto_info do_op = slug
          debug what e = emit $ Imp.DebugPrint what $ Just e

      hist_H <- toExp $ histWidth op

      -- Histogram width divided (rounding up) by the number of chunks.
      hist_H_chk <- dPrimVE "hist_H_chk" $ hist_H `quotRoundingUp` hist_S
      debug "Chunk size (H_chk)" hist_H_chk

      let l2_elems = r64 hist_L2 / r64 (slugElSize slug)
          k_bound = Imp.BinOpExp (FMin Float64)
                    (hist_F_L2 * l2_elems * hist_RACE_exp)
                    (r64 hist_N)
      hist_k_max <- dPrimVE "hist_k_max" $ k_bound / r64 hist_T

      hist_u <- dPrimVE "hist_u" $ case do_op of
        AtomicPrim{} -> 2
        _ -> 1

      hist_C <- dPrimVE "hist_C" $
        Imp.BinOpExp (FMin Float64) (r64 hist_T)
        (r64 (hist_u * hist_H_chk) / hist_k_max)

      -- Number of subhistograms per result histogram.
      hist_M <- dPrimVE "hist_M" $ case slugAtomicUpdate slug of
        AtomicPrim{} -> 1
        _ -> Imp.BinOpExp (SMax Int32) hist_M_min (t64 (r64 hist_T / hist_C))

      debug "Elements/thread in L2 cache (k_max)" hist_k_max
      debug "Multiplication degree (M)" hist_M
      debug "Cooperation level (C)" hist_C

      -- num_subhistos is the variable we use to communicate back.
      num_subhistos <-- hist_M

      -- Initialise sub-histograms.
      --
      -- When hist_M is 1 the subhistogram memory is simply pointed at
      -- the original destination instead of being allocated.  The idea
      -- is to avoid a copy if we are writing a small number of values
      -- into a very large prior histogram.
      let initSubhistos dest info = do
            dest_mem <- memLocationName . entryArrayLocation <$> lookupArray dest
            sub_mem <- memLocationName . entryArrayLocation <$>
                       lookupArray (subhistosArray info)
            sIf (hist_M .==. 1)
              (emit $ Imp.SetMem sub_mem dest_mem $ Space "device")
              (subhistosAlloc info)
            return $ subhistosArray info
      dests <- zipWithM initSubhistos (histDest op) subhisto_info

      prepareAtomicUpdateGlobal l dests slug

histKernelGlobalPass :: [PatElem ExplicitMemory]
                     -> Count NumGroups Imp.Exp
                     -> Count GroupSize Imp.Exp
                     -> SegSpace
                     -> [SegHistSlug]
                     -> KernelBody ExplicitMemory
                     -> [[Imp.Exp] -> InKernelGen ()]
                     -> Imp.Exp -> Imp.Exp
                     -> CallKernelGen ()
histKernelGlobalPass :: [PatElem ExplicitMemory]
-> Count NumGroups Exp
-> Count GroupSize Exp
-> SegSpace
-> [SegHistSlug]
-> KernelBody ExplicitMemory
-> [[Exp] -> InKernelGen ()]
-> Exp
-> Exp
-> CallKernelGen ()
histKernelGlobalPass [PatElem ExplicitMemory]
map_pes Count NumGroups Exp
num_groups Count GroupSize Exp
group_size SegSpace
space [SegHistSlug]
slugs KernelBody ExplicitMemory
kbody [[Exp] -> InKernelGen ()]
histograms Exp
hist_S Exp
chk_i = do

  let ([VName]
space_is, [SubExp]
space_sizes) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      space_sizes_64 :: [Exp]
space_sizes_64 = (SubExp -> Exp) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (Exp -> Exp
forall v. PrimExp v -> PrimExp v
i32Toi64 (Exp -> Exp) -> (SubExp -> Exp) -> SubExp -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) [SubExp]
space_sizes
      total_w_64 :: Exp
total_w_64 = [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Exp]
space_sizes_64

  [Exp]
hist_H_chks <- [SubExp]
-> (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> ImpM ExplicitMemory HostEnv HostOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ((SegHistSlug -> SubExp) -> [SegHistSlug] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (HistOp ExplicitMemory -> SubExp
forall lore. HistOp lore -> SubExp
histWidth (HistOp ExplicitMemory -> SubExp)
-> (SegHistSlug -> HistOp ExplicitMemory) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp ExplicitMemory
slugOp) [SegHistSlug]
slugs) ((SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
 -> ImpM ExplicitMemory HostEnv HostOp [Exp])
-> (SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> ImpM ExplicitMemory HostEnv HostOp [Exp]
forall a b. (a -> b) -> a -> b
$ \SubExp
w -> do
    Exp
w' <- SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
w
    String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_H_chk" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
w' Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp` Exp
hist_S

  String
-> Count NumGroups Exp
-> Count GroupSize Exp
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"seghist_global" Count NumGroups Exp
num_groups Count GroupSize Exp
group_size (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
-> ImpM ExplicitMemory KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv

    -- Compute subhistogram index for each thread, per histogram.
    [Exp]
subhisto_inds <- [SegHistSlug]
-> (SegHistSlug -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> ImpM ExplicitMemory KernelEnv KernelOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [SegHistSlug]
slugs ((SegHistSlug -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
 -> ImpM ExplicitMemory KernelEnv KernelOp [Exp])
-> (SegHistSlug -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> ImpM ExplicitMemory KernelEnv KernelOp [Exp]
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug ->
      String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"subhisto_ind" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
      KernelConstants -> Exp
kernelGlobalThreadId KernelConstants
constants Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot`
      (KernelConstants -> Exp
kernelNumThreads KernelConstants
constants Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp` VName -> Exp
Imp.vi32 (SegHistSlug -> VName
slugNumSubhistos SegHistSlug
slug))

    -- Loop over flat offsets into the input and output.  The
    -- calculation is done with 64-bit integers to avoid overflow,
    -- but the final unflattened segment indexes are 32 bit.
    let gtid :: Exp
gtid = Exp -> Exp
forall v. PrimExp v -> PrimExp v
i32Toi64 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ KernelConstants -> Exp
kernelGlobalThreadId KernelConstants
constants
        num_threads :: Exp
num_threads = Exp -> Exp
forall v. PrimExp v -> PrimExp v
i32Toi64 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ KernelConstants -> Exp
kernelNumThreads KernelConstants
constants
    Exp -> Exp -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
kernelLoop Exp
gtid Exp
num_threads Exp
total_w_64 ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
offset -> do

      -- Construct segment indices.
      let setIndex :: VName -> Exp -> ImpM lore r op ()
setIndex VName
v Exp
e = do VName -> PrimType -> ImpM lore r op ()
forall lore r op. VName -> PrimType -> ImpM lore r op ()
dPrim_ VName
v PrimType
int32
                            VName
v VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<-- Exp
e
      (VName -> Exp -> InKernelGen ())
-> [VName] -> [Exp] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
setIndex [VName]
space_is ([Exp] -> InKernelGen ()) -> [Exp] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        (Exp -> Exp) -> [Exp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (ConvOp -> Exp -> Exp
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> IntType -> ConvOp
SExt IntType
Int64 IntType
Int32)) ([Exp] -> [Exp]) -> [Exp] -> [Exp]
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [Exp]
space_sizes_64 Exp
offset

      -- We execute the bucket function once and update each histogram
      -- serially.  The bucket function is applied only if the flat
      -- offset is less than the total size of the iteration space
      -- (total_w_64).  This also involves writing to the mapout arrays.
      let input_in_bounds :: Exp
input_in_bounds = Exp
offset Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
total_w_64

      Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
input_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Names -> Stms ExplicitMemory -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody ExplicitMemory -> Stms ExplicitMemory
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody ExplicitMemory
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
        let ([KernelResult]
red_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElemT (MemInfo SubExp NoUniqueness MemBind)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem ExplicitMemory]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
map_pes) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody ExplicitMemory -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody ExplicitMemory
kbody

        String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"save map-out results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          [(PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
-> ((PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [KernelResult]
-> [(PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem ExplicitMemory]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
map_pes [KernelResult]
map_res) (((PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((PatElemT (MemInfo SubExp NoUniqueness MemBind), KernelResult)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe, KernelResult
res) ->
          VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe)
          (((VName, SubExp) -> Exp) -> [(VName, SubExp)] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> Exp
Imp.vi32 (VName -> Exp)
-> ((VName, SubExp) -> VName) -> (VName, SubExp) -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst) ([(VName, SubExp)] -> [Exp]) -> [(VName, SubExp)] -> [Exp]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space)
          (KernelResult -> SubExp
kernelResultSubExp KernelResult
res) []

        let ([KernelResult]
buckets, [KernelResult]
vs) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegHistSlug] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SegHistSlug]
slugs) [KernelResult]
red_res
            perOp :: [KernelResult] -> [[KernelResult]]
perOp = [Int] -> [KernelResult] -> [[KernelResult]]
forall a. [Int] -> [a] -> [[a]]
chunks ([Int] -> [KernelResult] -> [[KernelResult]])
-> [Int] -> [KernelResult] -> [[KernelResult]]
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Int) -> [SegHistSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int) -> (SegHistSlug -> [VName]) -> SegHistSlug -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp ExplicitMemory -> [VName]
forall lore. HistOp lore -> [VName]
histDest (HistOp ExplicitMemory -> [VName])
-> (SegHistSlug -> HistOp ExplicitMemory) -> SegHistSlug -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp ExplicitMemory
slugOp) [SegHistSlug]
slugs

        String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          [(HistOp ExplicitMemory, [Exp] -> InKernelGen (), KernelResult,
  [KernelResult], Exp, Exp)]
-> ((HistOp ExplicitMemory, [Exp] -> InKernelGen (), KernelResult,
     [KernelResult], Exp, Exp)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp ExplicitMemory]
-> [[Exp] -> InKernelGen ()]
-> [KernelResult]
-> [[KernelResult]]
-> [Exp]
-> [Exp]
-> [(HistOp ExplicitMemory, [Exp] -> InKernelGen (), KernelResult,
     [KernelResult], Exp, Exp)]
forall a b c d e f.
[a] -> [b] -> [c] -> [d] -> [e] -> [f] -> [(a, b, c, d, e, f)]
zip6 ((SegHistSlug -> HistOp ExplicitMemory)
-> [SegHistSlug] -> [HistOp ExplicitMemory]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp ExplicitMemory
slugOp [SegHistSlug]
slugs) [[Exp] -> InKernelGen ()]
histograms [KernelResult]
buckets ([KernelResult] -> [[KernelResult]]
perOp [KernelResult]
vs) [Exp]
subhisto_inds [Exp]
hist_H_chks) (((HistOp ExplicitMemory, [Exp] -> InKernelGen (), KernelResult,
   [KernelResult], Exp, Exp)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((HistOp ExplicitMemory, [Exp] -> InKernelGen (), KernelResult,
     [KernelResult], Exp, Exp)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          \(HistOp SubExp
dest_w SubExp
_ [VName]
_ [SubExp]
_ Shape
shape LambdaT ExplicitMemory
lam,
            [Exp] -> InKernelGen ()
do_op, KernelResult
bucket, [KernelResult]
vs', Exp
subhisto_ind, Exp
hist_H_chk) -> do

            let chk_beg :: Exp
chk_beg = Exp
chk_i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_H_chk
                bucket' :: Exp
bucket' = PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp) -> SubExp -> Exp
forall a b. (a -> b) -> a -> b
$ KernelResult -> SubExp
kernelResultSubExp KernelResult
bucket
                dest_w' :: Exp
dest_w' = PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 SubExp
dest_w
                bucket_in_bounds :: Exp
bucket_in_bounds = Exp
chk_beg Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<=. Exp
bucket' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&.
                                   Exp
bucket' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. (Exp
chk_beg Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
hist_H_chk) Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&.
                                   Exp
bucket' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
dest_w'
                vs_params :: [Param (MemInfo SubExp NoUniqueness MemBind)]
vs_params = Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
takeLast ([KernelResult] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
vs') ([Param (MemInfo SubExp NoUniqueness MemBind)]
 -> [Param (MemInfo SubExp NoUniqueness MemBind)])
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a b. (a -> b) -> a -> b
$ LambdaT ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT ExplicitMemory
lam

            Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
bucket_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
              let bucket_is :: [Exp]
bucket_is = (VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Exp
Imp.vi32 ([VName] -> [VName]
forall a. [a] -> [a]
init [VName]
space_is) [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++
                              [Exp
subhisto_ind, Exp
bucket']
              [LParam ExplicitMemory] -> InKernelGen ()
forall lore r op.
ExplicitMemorish lore =>
[LParam lore] -> ImpM lore r op ()
dLParams ([LParam ExplicitMemory] -> InKernelGen ())
-> [LParam ExplicitMemory] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT ExplicitMemory
lam
              Shape -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([Exp] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest Shape
shape (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
is -> do
                [(Param (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), KernelResult)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [KernelResult]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
vs_params [KernelResult]
vs') (((Param (MemInfo SubExp NoUniqueness MemBind), KernelResult)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), KernelResult)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, KernelResult
res) ->
                  VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (KernelResult -> SubExp
kernelResultSubExp KernelResult
res) [Exp]
is
                [Exp] -> InKernelGen ()
do_op ([Exp]
bucket_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp]
is)


-- | Generate host code for a 'SegHist' whose subhistograms are kept in
-- global memory.
--
-- The group count and group size are first converted to expressions,
-- and their product is passed as the thread count to
-- 'prepareIntermediateArraysGlobal', together with the size of the
-- innermost dimension of the segment space.  That call yields the loop
-- bound @hist_S@ and one update function per histogram operation.  The
-- kernel itself is then emitted by 'histKernelGlobalPass' inside a
-- host-side loop that runs @chk_i@ from 0 up to @hist_S@.
--
-- NOTE(review): @hist_S@ looks like the number of passes over the
-- input, each pass handling one chunk of the histogram -- confirm
-- against 'prepareIntermediateArraysGlobal'.
histKernelGlobal :: [PatElem ExplicitMemory]
                 -> Count NumGroups SubExp -> Count GroupSize SubExp
                 -> SegSpace
                 -> [SegHistSlug]
                 -> KernelBody ExplicitMemory
                 -> CallKernelGen ()
histKernelGlobal map_pes num_groups group_size space slugs kbody = do
  num_groups' <- traverse toExp num_groups
  group_size' <- traverse toExp group_size
  let space_sizes = map snd $ unSegSpace space
      num_threads = unCount num_groups' * unCount group_size'

  emit $ Imp.DebugPrint "## Using global memory" Nothing

  -- The innermost dimension of the segment space is what gets passed
  -- on; 'last' assumes the segment space has at least one dimension.
  (hist_S, histograms) <-
    prepareIntermediateArraysGlobal (bodyPassage kbody)
    num_threads (toExp' int32 $ last space_sizes) slugs

  sFor "chk_i" hist_S $ \chk_i ->
    histKernelGlobalPass map_pes num_groups' group_size' space slugs kbody
    histograms hist_S chk_i

-- | The per-operation result of 'prepareIntermediateArraysLocal'; the
-- list has one entry for each histogram operation.
--
-- The first component of an entry names the global-memory subhistogram
-- arrays of that operation.  The second component is run inside the
-- kernel: given the chunk width (@hist_H_chk@), it allocates the
-- local-memory subhistograms and returns their names along with a
-- function that, given a list of index expressions, generates the
-- update code.
type InitLocalHistograms =
  [ ( [VName]
    , SubExp -> InKernelGen ([VName], [Imp.Exp] -> InKernelGen ())
    )
  ]

-- | Host-side preparation for the local-memory 'SegHist' strategy.
--
-- The arguments are, in order: the variable holding the number of
-- subhistograms per group, the number of groups per segment, the
-- segment space, and the histogram operations.
--
-- For every operation this
--
--   * sets the operation's subhistogram count to
--     @groups_per_segment * num_segments@, where @num_segments@ is the
--     product of all but the innermost dimension of the segment space;
--
--   * allocates the global-memory subhistograms through
--     'subhistosAlloc'; and
--
--   * builds a deferred in-kernel action that, once given the chunk
--     width, allocates the local-memory subhistograms (and a lock
--     array when the update strategy is 'AtomicLocking') and returns
--     the function that performs an update.
prepareIntermediateArraysLocal :: VName
                               -> Count NumGroups Imp.Exp
                               -> SegSpace -> [SegHistSlug]
                               -> CallKernelGen InitLocalHistograms
prepareIntermediateArraysLocal num_subhistos_per_group groups_per_segment space slugs = do
  -- 'init' drops the innermost dimension; this assumes the segment
  -- space is non-empty.
  num_segments <- dPrimVE "num_segments" $
                  product $ map (toExp' int32 . snd) $ init $ unSegSpace space
  mapM (onOp num_segments) slugs
  where
    onOp num_segments (SegHistSlug op num_subhistos subhisto_info do_op) = do

      num_subhistos <-- unCount groups_per_segment * num_segments

      emit $ Imp.DebugPrint "Number of subhistograms in global memory" $
        Just $ Imp.vi32 num_subhistos

      -- Building the update operation is deferred until the chunk
      -- width is known inside the kernel.  Only the locking strategy
      -- actually uses the width (to size its lock array).
      mk_op <-
        case do_op of
          AtomicPrim f -> return $ const $ return f
          AtomicCAS f -> return $ const $ return f
          AtomicLocking f -> return $ \hist_H_chk -> do
            -- One lock per subhistogram, per histogram-shape index,
            -- per bucket of the chunk.
            let lock_shape =
                  Shape $ Var num_subhistos_per_group :
                  shapeDims (histShape op) ++
                  [hist_H_chk]

            dims <- mapM toExp $ shapeDims lock_shape

            locks <- sAllocArray "locks" int32 lock_shape $ Space "local"

            sComment "All locks start out unlocked" $
              groupCoverSpace dims $ \is ->
              copyDWIMFix locks is (intConst Int32 0) []

            -- NOTE(review): the three constants are presumably the
            -- "unlocked" value, the value written to lock, and the
            -- value written to unlock -- confirm against 'Locking'.
            return $ f $ Locking locks 0 1 0 id

      -- Initialise local-memory sub-histograms.  Each one has an outer
      -- dimension of num_subhistos_per_group, followed by the shape of
      -- the histogram's element type with its outer dimension replaced
      -- by the chunk width.
      let init_local_subhistos hist_H_chk = do
            local_subhistos <-
              forM (histType op) $ \t -> do
                let sub_local_shape =
                      Shape [Var num_subhistos_per_group] <>
                      (arrayShape t `setOuterDim` hist_H_chk)
                sAllocArray "subhistogram_local"
                  (elemType t) sub_local_shape (Space "local")

            do_op' <- mk_op hist_H_chk

            return (local_subhistos, do_op' (Space "local") local_subhistos)

      -- Initialise global-memory sub-histograms.
      glob_subhistos <- forM subhisto_info $ \info -> do
        subhistosAlloc info
        return $ subhistosArray info

      return (glob_subhistos, init_local_subhistos)

histKernelLocalPass :: VName -> Count NumGroups Imp.Exp
                    -> [PatElem ExplicitMemory]
                    -> Count NumGroups Imp.Exp -> Count GroupSize Imp.Exp
                    -> SegSpace
                    -> [SegHistSlug]
                    -> KernelBody ExplicitMemory
                    -> InitLocalHistograms -> Imp.Exp -> Imp.Exp
                    -> CallKernelGen ()
histKernelLocalPass :: VName
-> Count NumGroups Exp
-> [PatElem ExplicitMemory]
-> Count NumGroups Exp
-> Count GroupSize Exp
-> SegSpace
-> [SegHistSlug]
-> KernelBody ExplicitMemory
-> InitLocalHistograms
-> Exp
-> Exp
-> CallKernelGen ()
histKernelLocalPass VName
num_subhistos_per_group_var Count NumGroups Exp
groups_per_segment [PatElem ExplicitMemory]
map_pes Count NumGroups Exp
num_groups Count GroupSize Exp
group_size SegSpace
space [SegHistSlug]
slugs KernelBody ExplicitMemory
kbody
                    InitLocalHistograms
init_histograms Exp
hist_S Exp
chk_i = do
  let ([VName]
space_is, [SubExp]
space_sizes) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      segment_is :: [VName]
segment_is = [VName] -> [VName]
forall a. [a] -> [a]
init [VName]
space_is
      segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. [a] -> [a]
init [SubExp]
space_sizes
      (VName
i_in_segment, SubExp
segment_size) = [(VName, SubExp)] -> (VName, SubExp)
forall a. [a] -> a
last ([(VName, SubExp)] -> (VName, SubExp))
-> [(VName, SubExp)] -> (VName, SubExp)
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      num_subhistos_per_group :: Exp
num_subhistos_per_group = VName -> PrimType -> Exp
Imp.var VName
num_subhistos_per_group_var PrimType
int32

  Exp
segment_size' <- SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
segment_size

  Exp
num_segments <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"num_segments" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
                  [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ (SubExp -> Exp) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) [SubExp]
segment_dims

  [VName]
hist_H_chks <- [SubExp]
-> (SubExp -> ImpM ExplicitMemory HostEnv HostOp VName)
-> ImpM ExplicitMemory HostEnv HostOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ((SegHistSlug -> SubExp) -> [SegHistSlug] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (HistOp ExplicitMemory -> SubExp
forall lore. HistOp lore -> SubExp
histWidth (HistOp ExplicitMemory -> SubExp)
-> (SegHistSlug -> HistOp ExplicitMemory) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp ExplicitMemory
slugOp) [SegHistSlug]
slugs) ((SubExp -> ImpM ExplicitMemory HostEnv HostOp VName)
 -> ImpM ExplicitMemory HostEnv HostOp [VName])
-> (SubExp -> ImpM ExplicitMemory HostEnv HostOp VName)
-> ImpM ExplicitMemory HostEnv HostOp [VName]
forall a b. (a -> b) -> a -> b
$ \SubExp
w -> do
    Exp
w' <- SubExp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
w
    String -> Exp -> ImpM ExplicitMemory HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"hist_H_chk" (Exp -> ImpM ExplicitMemory HostEnv HostOp VName)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ Exp
w' Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp` Exp
hist_S

  String
-> Count NumGroups Exp
-> Count GroupSize Exp
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"seghist_local" Count NumGroups Exp
num_groups Count GroupSize Exp
group_size (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
    SegVirt -> Exp -> (VName -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt (Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
num_segments) ((VName -> InKernelGen ()) -> InKernelGen ())
-> (VName -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
group_id_var -> do

    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
-> ImpM ExplicitMemory KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv

    let group_id :: Exp
group_id = VName -> Exp
Imp.vi32 VName
group_id_var

    Exp
flat_segment_id <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"flat_segment_id" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
group_id Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment
    Exp
gid_in_segment <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"gid_in_segment" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
group_id Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`rem` Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment
    -- This pgtid is kind of a "virtualised physical" gtid - not the
    -- same thing as the gtid used for the SegHist itself.
    Exp
pgtid_in_segment <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"pgtid_in_segment" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
      Exp
gid_in_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupSize KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
    Exp
threads_per_segment <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"threads_per_segment" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
      Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupSize KernelConstants
constants

    -- Set segment indices.
    (VName -> Exp -> InKernelGen ())
-> [VName] -> [Exp] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
dPrimV_ [VName]
segment_is ([Exp] -> InKernelGen ()) -> [Exp] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ((SubExp -> Exp) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) [SubExp]
segment_dims) Exp
flat_segment_id

    [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
histograms <- [(([VName],
   SubExp
   -> ImpM
        ExplicitMemory
        KernelEnv
        KernelOp
        ([VName], [Exp] -> InKernelGen ())),
  VName)]
-> ((([VName],
      SubExp
      -> ImpM
           ExplicitMemory
           KernelEnv
           KernelOp
           ([VName], [Exp] -> InKernelGen ())),
     VName)
    -> ImpM
         ExplicitMemory
         KernelEnv
         KernelOp
         ([(VName, VName)], VName, [Exp] -> InKernelGen ()))
-> ImpM
     ExplicitMemory
     KernelEnv
     KernelOp
     [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (InitLocalHistograms
-> [VName]
-> [(([VName],
      SubExp
      -> ImpM
           ExplicitMemory
           KernelEnv
           KernelOp
           ([VName], [Exp] -> InKernelGen ())),
     VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip InitLocalHistograms
init_histograms [VName]
hist_H_chks) (((([VName],
    SubExp
    -> ImpM
         ExplicitMemory
         KernelEnv
         KernelOp
         ([VName], [Exp] -> InKernelGen ())),
   VName)
  -> ImpM
       ExplicitMemory
       KernelEnv
       KernelOp
       ([(VName, VName)], VName, [Exp] -> InKernelGen ()))
 -> ImpM
      ExplicitMemory
      KernelEnv
      KernelOp
      [([(VName, VName)], VName, [Exp] -> InKernelGen ())])
-> ((([VName],
      SubExp
      -> ImpM
           ExplicitMemory
           KernelEnv
           KernelOp
           ([VName], [Exp] -> InKernelGen ())),
     VName)
    -> ImpM
         ExplicitMemory
         KernelEnv
         KernelOp
         ([(VName, VName)], VName, [Exp] -> InKernelGen ()))
-> ImpM
     ExplicitMemory
     KernelEnv
     KernelOp
     [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
forall a b. (a -> b) -> a -> b
$
                  \(([VName]
glob_subhistos, SubExp
-> ImpM
     ExplicitMemory
     KernelEnv
     KernelOp
     ([VName], [Exp] -> InKernelGen ())
init_local_subhistos), VName
hist_H_chk) -> do
      ([VName]
local_subhistos, [Exp] -> InKernelGen ()
do_op) <- SubExp
-> ImpM
     ExplicitMemory
     KernelEnv
     KernelOp
     ([VName], [Exp] -> InKernelGen ())
init_local_subhistos (SubExp
 -> ImpM
      ExplicitMemory
      KernelEnv
      KernelOp
      ([VName], [Exp] -> InKernelGen ()))
-> SubExp
-> ImpM
     ExplicitMemory
     KernelEnv
     KernelOp
     ([VName], [Exp] -> InKernelGen ())
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
hist_H_chk
      ([(VName, VName)], VName, [Exp] -> InKernelGen ())
-> ImpM
     ExplicitMemory
     KernelEnv
     KernelOp
     ([(VName, VName)], VName, [Exp] -> InKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
glob_subhistos [VName]
local_subhistos, VName
hist_H_chk, [Exp] -> InKernelGen ()
do_op)

    -- Find index of local subhistograms updated by this thread.  We
    -- try to ensure, as much as possible, that threads in the same
    -- warp use different subhistograms, to avoid conflicts.
    Exp
thread_local_subhisto_i <-
      String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"thread_local_subhisto_i" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
      KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`rem` Exp
num_subhistos_per_group

    -- Apply 'f' once per histogram operator: every slug is zipped with
    -- its entry in 'histograms'.  'f' receives the slug, its (global,
    -- local) destination array pairs, the chunk size as an expression,
    -- the dimensions of one subhistogram chunk, and the product of those
    -- dimensions.  The update function stored in the 'histograms' entry
    -- is ignored here.
    let onSlugs :: (SegHistSlug
 -> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ()
onSlugs SegHistSlug
-> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ()
f = [(SegHistSlug, ([(VName, VName)], VName, [Exp] -> InKernelGen ()))]
-> ((SegHistSlug,
     ([(VName, VName)], VName, [Exp] -> InKernelGen ()))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegHistSlug]
-> [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
-> [(SegHistSlug,
     ([(VName, VName)], VName, [Exp] -> InKernelGen ()))]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegHistSlug]
slugs [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
histograms) (((SegHistSlug, ([(VName, VName)], VName, [Exp] -> InKernelGen ()))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((SegHistSlug,
     ([(VName, VName)], VName, [Exp] -> InKernelGen ()))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegHistSlug
slug, ([(VName, VName)]
dests, VName
hist_H_chk, [Exp] -> InKernelGen ()
_)) -> do
          -- Chunk size first, followed by the dimensions of the
          -- operator's shape, all converted to int32 expressions.
          let histo_dims :: [Exp]
histo_dims = (SubExp -> Exp) -> [SubExp] -> [Exp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) ([SubExp] -> [Exp]) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
hist_H_chk SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
:
                           Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp ExplicitMemory -> Shape
forall lore. HistOp lore -> Shape
histShape (SegHistSlug -> HistOp ExplicitMemory
slugOp SegHistSlug
slug))
          -- Flat number of elements in one subhistogram chunk, bound
          -- through 'dPrimVE' under the name hint "histo_size".
          Exp
histo_size <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"histo_size" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Exp]
histo_dims
          SegHistSlug
-> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ()
f SegHistSlug
slug [(VName, VName)]
dests (VName -> Exp
Imp.vi32 VName
hist_H_chk) [Exp]
histo_dims Exp
histo_size

    -- Generate code that visits every element of the subhistograms held
    -- by this group, for every destination array of every operator, and
    -- emits 'f' for each such element.  The elements are divided among
    -- the threads of the group: each thread runs 'init_per_thread'
    -- iterations, and in iteration i handles flat element
    -- j = i * group size + local thread id (guarded so that j stays
    -- below 'group_hists_size').
    --
    -- 'f' receives, in order: the local destination array, the global
    -- destination array, the operator, the neutral element, the local
    -- and global subhistogram indices, and the local and global bucket
    -- indices.
    let onAllHistograms :: (VName
 -> VName
 -> HistOp ExplicitMemory
 -> SubExp
 -> Exp
 -> Exp
 -> [Exp]
 -> [Exp]
 -> InKernelGen ())
-> InKernelGen ()
onAllHistograms VName
-> VName
-> HistOp ExplicitMemory
-> SubExp
-> Exp
-> Exp
-> [Exp]
-> [Exp]
-> InKernelGen ()
f =
          (SegHistSlug
 -> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
  -> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
 -> InKernelGen ())
-> (SegHistSlug
    -> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests Exp
hist_H_chk [Exp]
histo_dims Exp
histo_size -> do
            -- Total number of elements across all of this group's
            -- subhistograms for one destination array.
            let group_hists_size :: Exp
group_hists_size = Exp
num_subhistos_per_group Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
histo_size
            -- Iterations needed per thread to cover 'group_hists_size'
            -- elements with one group (rounded up).
            Exp
init_per_thread <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"init_per_thread" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
                               Exp
group_hists_size Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp`
                               KernelConstants -> Exp
kernelGroupSize KernelConstants
constants

            -- One loop per destination pair, matched positionally with
            -- the operator's neutral elements.
            [((VName, VName), SubExp)]
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, VName)] -> [SubExp] -> [((VName, VName), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(VName, VName)]
dests (HistOp ExplicitMemory -> [SubExp]
forall lore. HistOp lore -> [SubExp]
histNeutral (HistOp ExplicitMemory -> [SubExp])
-> HistOp ExplicitMemory -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp ExplicitMemory
slugOp SegHistSlug
slug)) ((((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ())
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              \((VName
dest_global, VName
dest_local), SubExp
ne) ->
                String -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
String -> Exp -> (Exp -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"local_i" Exp
init_per_thread ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
i -> do
                  -- Flat element index within this group's subhistograms
                  -- handled by this thread in iteration i.
                  Exp
j <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"j" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
                       Exp
i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupSize KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+
                       KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
                  -- Same element, offset by the elements belonging to the
                  -- groups that precede this one within the segment.
                  Exp
j_offset <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"j_offset" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
                              Exp
num_subhistos_per_group Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
histo_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
gid_in_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
j

                  -- Which of the group's subhistograms element j falls in.
                  Exp
local_subhisto_i <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"local_subhisto_i" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
j Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` Exp
histo_size
                  -- Position inside that subhistogram, unflattened over
                  -- 'histo_dims'.  The global position differs only in
                  -- its first index, which is shifted by
                  -- chk_i * hist_H_chk (chk_i is defined outside this
                  -- binding; presumably the current chunk index).
                  -- NOTE(review): 'head'/'tail' are safe only because
                  -- 'histo_dims' as built by 'onSlugs' is never empty.
                  let local_bucket_is :: [Exp]
local_bucket_is = [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [Exp]
histo_dims (Exp -> [Exp]) -> Exp -> [Exp]
forall a b. (a -> b) -> a -> b
$ Exp
j Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`rem` Exp
histo_size
                      global_bucket_is :: [Exp]
global_bucket_is = [Exp] -> Exp
forall a. [a] -> a
head [Exp]
local_bucket_is Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
chk_i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_H_chk Exp -> [Exp] -> [Exp]
forall a. a -> [a] -> [a]
:
                                         [Exp] -> [Exp]
forall a. [a] -> [a]
tail [Exp]
local_bucket_is
                  -- Subhistogram index counted across the groups of the
                  -- segment rather than within this group.
                  Exp
global_subhisto_i <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"global_subhisto_i" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
j_offset Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` Exp
histo_size

                  -- The rounded-up iteration count can overshoot, so only
                  -- act on elements that actually exist.  Note that 'f'
                  -- takes the local destination before the global one,
                  -- the reverse of the order in the 'dests' pairs.
                  Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
j Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
group_hists_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                    VName
-> VName
-> HistOp ExplicitMemory
-> SubExp
-> Exp
-> Exp
-> [Exp]
-> [Exp]
-> InKernelGen ()
f VName
dest_local VName
dest_global (SegHistSlug -> HistOp ExplicitMemory
slugOp SegHistSlug
slug) SubExp
ne
                    Exp
local_subhisto_i Exp
global_subhisto_i
                    [Exp]
local_bucket_is [Exp]
global_bucket_is

    String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"initialize histograms in local memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      (VName
 -> VName
 -> HistOp ExplicitMemory
 -> SubExp
 -> Exp
 -> Exp
 -> [Exp]
 -> [Exp]
 -> InKernelGen ())
-> InKernelGen ()
onAllHistograms ((VName
  -> VName
  -> HistOp ExplicitMemory
  -> SubExp
  -> Exp
  -> Exp
  -> [Exp]
  -> [Exp]
  -> InKernelGen ())
 -> InKernelGen ())
-> (VName
    -> VName
    -> HistOp ExplicitMemory
    -> SubExp
    -> Exp
    -> Exp
    -> [Exp]
    -> [Exp]
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
dest_local VName
dest_global HistOp ExplicitMemory
op SubExp
ne Exp
local_subhisto_i Exp
global_subhisto_i [Exp]
local_bucket_is [Exp]
global_bucket_is ->
      String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"First subhistogram is initialised from global memory; others with neutral element." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
      let global_is :: [Exp]
global_is = (VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Exp
Imp.vi32 [VName]
segment_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp
0] [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp]
global_bucket_is
          local_is :: [Exp]
local_is = Exp
local_subhisto_i Exp -> [Exp] -> [Exp]
forall a. a -> [a] -> [a]
: [Exp]
local_bucket_is
      Exp -> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall lore r op.
Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (Exp
global_subhisto_i Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0)
        (VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
dest_local [Exp]
local_is (VName -> SubExp
Var VName
dest_global) [Exp]
global_is)
        (Shape -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([Exp] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (HistOp ExplicitMemory -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp ExplicitMemory
op) (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
is ->
            VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
dest_local ([Exp]
local_is[Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++[Exp]
is) SubExp
ne [])

    KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal

    Exp -> Exp -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
kernelLoop Exp
pgtid_in_segment Exp
threads_per_segment Exp
segment_size' ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
ie -> do
      VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
dPrimV_ VName
i_in_segment Exp
ie

      -- We execute the bucket function once and update each histogram
      -- serially.  This also involves writing to the mapout arrays if
      -- this is the first chunk.

      Names -> Stms ExplicitMemory -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody ExplicitMemory -> Stms ExplicitMemory
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody ExplicitMemory
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do

        let ([SubExp]
red_res, [SubExp]
map_res) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElemT (MemInfo SubExp NoUniqueness MemBind)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem ExplicitMemory]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
map_pes) ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$
                             (KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp ([KernelResult] -> [SubExp]) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ KernelBody ExplicitMemory -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody ExplicitMemory
kbody
            ([SubExp]
buckets, [SubExp]
vs) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegHistSlug] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SegHistSlug]
slugs) [SubExp]
red_res
            perOp :: [SubExp] -> [[SubExp]]
perOp = [Int] -> [SubExp] -> [[SubExp]]
forall a. [Int] -> [a] -> [[a]]
chunks ([Int] -> [SubExp] -> [[SubExp]])
-> [Int] -> [SubExp] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Int) -> [SegHistSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int) -> (SegHistSlug -> [VName]) -> SegHistSlug -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp ExplicitMemory -> [VName]
forall lore. HistOp lore -> [VName]
histDest (HistOp ExplicitMemory -> [VName])
-> (SegHistSlug -> HistOp ExplicitMemory) -> SegHistSlug -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp ExplicitMemory
slugOp) [SegHistSlug]
slugs

        Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
chk_i Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"save map-out results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          [(PatElemT (MemInfo SubExp NoUniqueness MemBind), SubExp)]
-> ((PatElemT (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [SubExp]
-> [(PatElemT (MemInfo SubExp NoUniqueness MemBind), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem ExplicitMemory]
[PatElemT (MemInfo SubExp NoUniqueness MemBind)]
map_pes [SubExp]
map_res) (((PatElemT (MemInfo SubExp NoUniqueness MemBind), SubExp)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((PatElemT (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe, SubExp
se) ->
          VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (MemInfo SubExp NoUniqueness MemBind)
pe)
          ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Exp
Imp.vi32 [VName]
space_is) SubExp
se []

        [(HistOp ExplicitMemory,
  ([(VName, VName)], VName, [Exp] -> InKernelGen ()), SubExp,
  [SubExp])]
-> ((HistOp ExplicitMemory,
     ([(VName, VName)], VName, [Exp] -> InKernelGen ()), SubExp,
     [SubExp])
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp ExplicitMemory]
-> [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
-> [SubExp]
-> [[SubExp]]
-> [(HistOp ExplicitMemory,
     ([(VName, VName)], VName, [Exp] -> InKernelGen ()), SubExp,
     [SubExp])]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 ((SegHistSlug -> HistOp ExplicitMemory)
-> [SegHistSlug] -> [HistOp ExplicitMemory]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp ExplicitMemory
slugOp [SegHistSlug]
slugs) [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
histograms [SubExp]
buckets ([SubExp] -> [[SubExp]]
perOp [SubExp]
vs)) (((HistOp ExplicitMemory,
   ([(VName, VName)], VName, [Exp] -> InKernelGen ()), SubExp,
   [SubExp])
  -> InKernelGen ())
 -> InKernelGen ())
-> ((HistOp ExplicitMemory,
     ([(VName, VName)], VName, [Exp] -> InKernelGen ()), SubExp,
     [SubExp])
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          \(HistOp SubExp
dest_w SubExp
_ [VName]
_ [SubExp]
_ Shape
shape LambdaT ExplicitMemory
lam,
            ([(VName, VName)]
_, VName
hist_H_chk, [Exp] -> InKernelGen ()
do_op), SubExp
bucket, [SubExp]
vs') -> do

            let chk_beg :: Exp
chk_beg = Exp
chk_i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* VName -> Exp
Imp.vi32 VName
hist_H_chk
                bucket' :: Exp
bucket' = PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 SubExp
bucket
                dest_w' :: Exp
dest_w' = PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 SubExp
dest_w
                bucket_in_bounds :: Exp
bucket_in_bounds = Exp
bucket' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
dest_w' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&.
                                   Exp
chk_beg Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<=. Exp
bucket' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&.
                                   Exp
bucket' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. (Exp
chk_beg Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ VName -> Exp
Imp.vi32 VName
hist_H_chk)
                bucket_is :: [Exp]
bucket_is = [Exp
thread_local_subhisto_i, Exp
bucket' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
chk_beg]
                vs_params :: [Param (MemInfo SubExp NoUniqueness MemBind)]
vs_params = Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') ([Param (MemInfo SubExp NoUniqueness MemBind)]
 -> [Param (MemInfo SubExp NoUniqueness MemBind)])
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a b. (a -> b) -> a -> b
$ LambdaT ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT ExplicitMemory
lam

            String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
bucket_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
              [LParam ExplicitMemory] -> InKernelGen ()
forall lore r op.
ExplicitMemorish lore =>
[LParam lore] -> ImpM lore r op ()
dLParams ([LParam ExplicitMemory] -> InKernelGen ())
-> [LParam ExplicitMemory] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT ExplicitMemory
lam
              Shape -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([Exp] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest Shape
shape (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
is -> do
                [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [SubExp]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
vs_params [SubExp]
vs') (((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), SubExp)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, SubExp
v) ->
                  VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] SubExp
v [Exp]
is
                [Exp] -> InKernelGen ()
do_op ([Exp]
bucket_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp]
is)

    KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceGlobal

    String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Compact the multiple local memory subhistograms to result in global memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      (SegHistSlug
 -> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
  -> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
 -> InKernelGen ())
-> (SegHistSlug
    -> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests Exp
hist_H_chk [Exp]
histo_dims Exp
histo_size -> do
      Exp
bins_per_thread <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"init_per_thread" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
                         Exp
histo_size Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp` KernelConstants -> Exp
kernelGroupSize KernelConstants
constants

      VName
trunc_H <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"trunc_H" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$
                 BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMin IntType
Int32) Exp
hist_H_chk (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
                 PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (HistOp ExplicitMemory -> SubExp
forall lore. HistOp lore -> SubExp
histWidth (SegHistSlug -> HistOp ExplicitMemory
slugOp SegHistSlug
slug)) Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
-
                 Exp
chk_i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* [Exp] -> Exp
forall a. [a] -> a
head [Exp]
histo_dims
      let trunc_histo_dims :: [Exp]
trunc_histo_dims = (SubExp -> Exp) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) ([SubExp] -> [Exp]) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
trunc_H SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
:
                             Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp ExplicitMemory -> Shape
forall lore. HistOp lore -> Shape
histShape (SegHistSlug -> HistOp ExplicitMemory
slugOp SegHistSlug
slug))
      Exp
trunc_histo_size <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"histo_size" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Exp]
trunc_histo_dims

      String -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
String -> Exp -> (Exp -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"local_i" Exp
bins_per_thread ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
i -> do
        Exp
j <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"j" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
             Exp
i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupSize KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
        Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
j Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
trunc_histo_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
          -- We are responsible for compacting the flat bin 'j', which
          -- we immediately unflatten.
          let local_bucket_is :: [Exp]
local_bucket_is = [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [Exp]
histo_dims Exp
j
              global_bucket_is :: [Exp]
global_bucket_is = [Exp] -> Exp
forall a. [a] -> a
head [Exp]
local_bucket_is Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
chk_i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_H_chk Exp -> [Exp] -> [Exp]
forall a. a -> [a] -> [a]
:
                                 [Exp] -> [Exp]
forall a. [a] -> [a]
tail [Exp]
local_bucket_is
          [LParam ExplicitMemory] -> InKernelGen ()
forall lore r op.
ExplicitMemorish lore =>
[LParam lore] -> ImpM lore r op ()
dLParams ([LParam ExplicitMemory] -> InKernelGen ())
-> [LParam ExplicitMemory] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams (LambdaT ExplicitMemory -> [LParam ExplicitMemory])
-> LambdaT ExplicitMemory -> [LParam ExplicitMemory]
forall a b. (a -> b) -> a -> b
$ HistOp ExplicitMemory -> LambdaT ExplicitMemory
forall lore. HistOp lore -> Lambda lore
histOp (HistOp ExplicitMemory -> LambdaT ExplicitMemory)
-> HistOp ExplicitMemory -> LambdaT ExplicitMemory
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp ExplicitMemory
slugOp SegHistSlug
slug
          let ([VName]
global_dests, [VName]
local_dests) = [(VName, VName)] -> ([VName], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, VName)]
dests
              ([Param (MemInfo SubExp NoUniqueness MemBind)]
xparams, [Param (MemInfo SubExp NoUniqueness MemBind)]
yparams) = Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
    [Param (MemInfo SubExp NoUniqueness MemBind)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
local_dests) ([Param (MemInfo SubExp NoUniqueness MemBind)]
 -> ([Param (MemInfo SubExp NoUniqueness MemBind)],
     [Param (MemInfo SubExp NoUniqueness MemBind)]))
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
    [Param (MemInfo SubExp NoUniqueness MemBind)])
forall a b. (a -> b) -> a -> b
$
                                   LambdaT ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams (LambdaT ExplicitMemory -> [LParam ExplicitMemory])
-> LambdaT ExplicitMemory -> [LParam ExplicitMemory]
forall a b. (a -> b) -> a -> b
$ HistOp ExplicitMemory -> LambdaT ExplicitMemory
forall lore. HistOp lore -> Lambda lore
histOp (HistOp ExplicitMemory -> LambdaT ExplicitMemory)
-> HistOp ExplicitMemory -> LambdaT ExplicitMemory
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp ExplicitMemory
slugOp SegHistSlug
slug

          String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Read values from subhistogram 0." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            [(Param (MemInfo SubExp NoUniqueness MemBind), VName)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), VName)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
xparams [VName]
local_dests) (((Param (MemInfo SubExp NoUniqueness MemBind), VName)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), VName)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
xp, VName
subhisto) ->
            VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix
            (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
xp) []
            (VName -> SubExp
Var VName
subhisto) (Exp
0Exp -> [Exp] -> [Exp]
forall a. a -> [a] -> [a]
:[Exp]
local_bucket_is)

          String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Accumulate based on values in other subhistograms." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            String -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
String -> Exp -> (Exp -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"subhisto_id" (Exp
num_subhistos_per_group Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1) ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
subhisto_id -> do
              [(Param (MemInfo SubExp NoUniqueness MemBind), VName)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), VName)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
yparams [VName]
local_dests) (((Param (MemInfo SubExp NoUniqueness MemBind), VName)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), VName)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
yp, VName
subhisto) ->
                VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix
                (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
yp) []
                (VName -> SubExp
Var VName
subhisto) (Exp
subhisto_id Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
1 Exp -> [Exp] -> [Exp]
forall a. a -> [a] -> [a]
: [Exp]
local_bucket_is)
              [Param (MemInfo SubExp NoUniqueness MemBind)]
-> Body ExplicitMemory -> InKernelGen ()
forall attr lore r op.
[Param attr] -> Body lore -> ImpM lore r op ()
compileBody' [Param (MemInfo SubExp NoUniqueness MemBind)]
xparams (Body ExplicitMemory -> InKernelGen ())
-> Body ExplicitMemory -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT ExplicitMemory -> Body ExplicitMemory
forall lore. LambdaT lore -> BodyT lore
lambdaBody (LambdaT ExplicitMemory -> Body ExplicitMemory)
-> LambdaT ExplicitMemory -> Body ExplicitMemory
forall a b. (a -> b) -> a -> b
$ HistOp ExplicitMemory -> LambdaT ExplicitMemory
forall lore. HistOp lore -> Lambda lore
histOp (HistOp ExplicitMemory -> LambdaT ExplicitMemory)
-> HistOp ExplicitMemory -> LambdaT ExplicitMemory
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp ExplicitMemory
slugOp SegHistSlug
slug

          String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Put final bucket value in global memory." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
            let global_is :: [Exp]
global_is =
                  (VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Exp
Imp.vi32 [VName]
segment_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++
                  [Exp
group_id Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`rem` Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment] [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++
                  [Exp]
global_bucket_is
            [(Param (MemInfo SubExp NoUniqueness MemBind), VName)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), VName)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
xparams [VName]
global_dests) (((Param (MemInfo SubExp NoUniqueness MemBind), VName)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), VName)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
xp, VName
global_dest) ->
              VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
global_dest [Exp]
global_is (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
xp) []

-- | Generate host-level code for the local-memory histogram strategy.
--
-- Arguments, in order:
--
-- * @num_subhistos_per_group_var@: variable read (as an @int32@) to
--   obtain the number of subhistograms per group.
--
-- * @groups_per_segment@: number of groups assigned to each segment.
--
-- * @map_pes@: pattern elements, passed unchanged to
--   'histKernelLocalPass'.
--
-- * @num_groups@, @group_size@: kernel launch dimensions; converted to
--   expressions with 'toExp' before use.
--
-- * @space@: the iteration space of the histogram operation.
--
-- * @hist_S@: number of passes (chunks); used as the bound of the
--   @chk_i@ loop.
--
-- * @slugs@, @kbody@: the histogram operations and the kernel body.
--
-- The generated code first emits a debug print of the subhistogram
-- count, then prepares the intermediate arrays once via
-- 'prepareIntermediateArraysLocal', and finally invokes
-- 'histKernelLocalPass' once per iteration of the @chk_i@ loop.
histKernelLocal :: VName -> Count NumGroups Imp.Exp
                -> [PatElem ExplicitMemory]
                -> Count NumGroups SubExp -> Count GroupSize SubExp
                -> SegSpace
                -> Imp.Exp
                -> [SegHistSlug]
                -> KernelBody ExplicitMemory
                -> CallKernelGen ()
histKernelLocal num_subhistos_per_group_var groups_per_segment map_pes num_groups group_size space hist_S slugs kbody = do
  num_groups' <- traverse toExp num_groups
  group_size' <- traverse toExp group_size
  let num_subhistos_per_group = Imp.var num_subhistos_per_group_var int32

  emit $ Imp.DebugPrint "Number of local subhistograms per group" $ Just num_subhistos_per_group

  -- The initialisers are constructed once, outside the pass loop, and
  -- handed to every pass.
  init_histograms <-
    prepareIntermediateArraysLocal num_subhistos_per_group_var groups_per_segment space slugs

  -- One invocation of the pass generator per chunk index; each pass
  -- receives both the total pass count and its own index.
  sFor "chk_i" hist_S $ \chk_i ->
    histKernelLocalPass
    num_subhistos_per_group_var groups_per_segment map_pes num_groups' group_size' space slugs kbody
    init_histograms hist_S chk_i

-- | The maximum number of passes we are willing to accept for this
-- kind of atomic update.
--
-- The limit depends only on which constructor 'slugAtomicUpdate'
-- yields for the slug: 3 for 'AtomicPrim', 4 for 'AtomicCAS', and 6
-- for 'AtomicLocking'.
--
-- NOTE(review): presumably the more expensive update mechanisms
-- tolerate more passes because local memory pays off more for them;
-- confirm the rationale for these constants.
slugMaxLocalMemPasses :: SegHistSlug -> Int
slugMaxLocalMemPasses slug =
  case slugAtomicUpdate slug of
    AtomicPrim _    -> 3
    AtomicCAS _     -> 4
    AtomicLocking _ -> 6

localMemoryCase :: [PatElem ExplicitMemory]
                -> Imp.Exp
                -> SegSpace
                -> Imp.Exp -> Imp.Exp -> Imp.Exp -> Imp.Exp
                -> [SegHistSlug]
                -> KernelBody ExplicitMemory
                -> CallKernelGen (Imp.Exp, CallKernelGen ())
localMemoryCase :: [PatElem ExplicitMemory]
-> Exp
-> SegSpace
-> Exp
-> Exp
-> Exp
-> Exp
-> [SegHistSlug]
-> KernelBody ExplicitMemory
-> CallKernelGen (Exp, CallKernelGen ())
localMemoryCase [PatElem ExplicitMemory]
map_pes Exp
hist_T SegSpace
space Exp
hist_H Exp
hist_el_size Exp
hist_N Exp
_ [SegHistSlug]
slugs KernelBody ExplicitMemory
kbody = do
  let space_sizes :: [SubExp]
space_sizes = SegSpace -> [SubExp]
segSpaceDims SegSpace
space
      segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. [a] -> [a]
init [SubExp]
space_sizes
      segmented :: Bool
segmented = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [SubExp]
segment_dims

  VName
hist_L <- String -> PrimType -> ImpM ExplicitMemory HostEnv HostOp VName
forall lore r op. String -> PrimType -> ImpM lore r op VName
dPrim String
"hist_L" PrimType
int32
  HostOp -> CallKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax VName
hist_L SizeClass
Imp.SizeLocalMemory

  VName
max_group_size <- String -> PrimType -> ImpM ExplicitMemory HostEnv HostOp VName
forall lore r op. String -> PrimType -> ImpM lore r op VName
dPrim String
"max_group_size" PrimType
int32
  HostOp -> CallKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax VName
max_group_size SizeClass
Imp.SizeGroup
  let group_size :: Count GroupSize SubExp
group_size = SubExp -> Count GroupSize SubExp
forall u e. e -> Count u e
Imp.Count (SubExp -> Count GroupSize SubExp)
-> SubExp -> Count GroupSize SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
max_group_size
  Count NumGroups SubExp
num_groups <- (VName -> Count NumGroups SubExp)
-> ImpM ExplicitMemory HostEnv HostOp VName
-> ImpM ExplicitMemory HostEnv HostOp (Count NumGroups SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Imp.Count (SubExp -> Count NumGroups SubExp)
-> (VName -> SubExp) -> VName -> Count NumGroups SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) (ImpM ExplicitMemory HostEnv HostOp VName
 -> ImpM ExplicitMemory HostEnv HostOp (Count NumGroups SubExp))
-> ImpM ExplicitMemory HostEnv HostOp VName
-> ImpM ExplicitMemory HostEnv HostOp (Count NumGroups SubExp)
forall a b. (a -> b) -> a -> b
$ String -> Exp -> ImpM ExplicitMemory HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"num_groups" (Exp -> ImpM ExplicitMemory HostEnv HostOp VName)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
                Exp
hist_T Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp` PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount Count GroupSize SubExp
group_size)
  let num_groups' :: Count NumGroups Exp
num_groups' = PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp) -> Count NumGroups SubExp -> Count NumGroups Exp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count NumGroups SubExp
num_groups
      group_size' :: Count GroupSize Exp
group_size' = PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp) -> Count GroupSize SubExp -> Count GroupSize Exp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count GroupSize SubExp
group_size

  let r64 :: PrimExp v -> PrimExp v
r64 = ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> FloatType -> ConvOp
SIToFP IntType
Int32 FloatType
Float64)
      t64 :: PrimExp v -> PrimExp v
t64 = ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (FloatType -> IntType -> ConvOp
FPToSI FloatType
Float64 IntType
Int32)
      i32_to_i64 :: PrimExp v -> PrimExp v
i32_to_i64 = ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> IntType -> ConvOp
SExt IntType
Int32 IntType
Int64)
      i64_to_i32 :: PrimExp v -> PrimExp v
i64_to_i32 = ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> IntType -> ConvOp
SExt IntType
Int64 IntType
Int32)

  -- M approximation.
  Exp
hist_m' <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_m_prime" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
             Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 (BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMin IntType
Int32)
                  (VName -> Exp
Imp.vi32 VName
hist_L Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` Exp
hist_el_size)
                  (Exp
hist_N Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp` Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
num_groups'))
             Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/ Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 Exp
hist_H

  let hist_B :: Exp
hist_B = Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size'

  -- M in the paper, but not adjusted for asymptotic efficiency.
  Exp
hist_M0 <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_M0" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
             BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMax IntType
Int32) Exp
1 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
             BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMin IntType
Int32) (Exp -> Exp
forall v. PrimExp v -> PrimExp v
t64 Exp
hist_m') Exp
hist_B

  -- Minimal sequential chunking factor.
  let q_small :: Exp
q_small = Exp
2

  -- The number of segments/histograms produced.
  Exp
hist_Nout <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_Nout" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ (SubExp -> Exp) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) [SubExp]
segment_dims

  Exp
hist_Nin <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_Nin" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp) -> SubExp -> Exp
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. [a] -> a
last [SubExp]
space_sizes

  -- Maximum M for work efficiency.
  Exp
work_asymp_M_max <-
    if Bool
segmented then do

      Exp
hist_T_hist_min <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_T_hist_min" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
                         Exp -> Exp
forall v. PrimExp v -> PrimExp v
i64_to_i32 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
                         BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMin IntType
Int64)
                         (Exp -> Exp
forall v. PrimExp v -> PrimExp v
i32_to_i64 Exp
hist_Nin Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp -> Exp
forall v. PrimExp v -> PrimExp v
i32_to_i64 Exp
hist_Nout) (Exp -> Exp
forall v. PrimExp v -> PrimExp v
i32_to_i64 Exp
hist_T)
                         Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp`
                         Exp -> Exp
forall v. PrimExp v -> PrimExp v
i32_to_i64 Exp
hist_Nout

      -- Number of groups, rounded up.
      let r :: Exp
r = Exp
hist_T_hist_min Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp` Exp
hist_B

      String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"work_asymp_M_max" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
hist_Nin Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` (Exp
r Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_H)

    else String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"work_asymp_M_max" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
         (Exp
hist_Nout Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_N) Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot`
         ((Exp
q_small Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
num_groups' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_H)
          Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` [SegHistSlug] -> Exp
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs)

  -- Number of subhistograms per result histogram.
  VName
hist_M <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"hist_M" (Exp -> ImpM ExplicitMemory HostEnv HostOp VName)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
            BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMin IntType
Int32) Exp
hist_M0 Exp
work_asymp_M_max

  -- hist_M may be zero (which we'll check for below), but we need it
  -- for some divisions first, so crudely make a nonzero form.
  let hist_M_nonzero :: Exp
hist_M_nonzero = BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMax IntType
Int32) Exp
1 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ VName -> Exp
Imp.vi32 VName
hist_M

  -- "Cooperation factor" - the number of threads cooperatively
  -- working on the same (sub)histogram.
  Exp
hist_C <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_C" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
            Exp
hist_B Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp` Exp
hist_M_nonzero

  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local hist_M0" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_M0
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local work asymp M max" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
work_asymp_M_max
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local C" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_C
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local B" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_B
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local M" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ VName -> Exp
Imp.vi32 VName
hist_M
  Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local memory needed" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
    Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ Exp
hist_H Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_el_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* VName -> Exp
Imp.vi32 VName
hist_M

  -- local_mem_needed is what we need to keep a single bucket in local
  -- memory - this is an absolute minimum.  We can fit anything else
  -- by doing multiple passes, although more than a few is
  -- (heuristically) not efficient.
  Exp
local_mem_needed <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"local_mem_needed" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
hist_el_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* VName -> Exp
Imp.vi32 VName
hist_M
  Exp
hist_S <- String -> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_S" (Exp -> ImpM ExplicitMemory HostEnv HostOp Exp)
-> Exp -> ImpM ExplicitMemory HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ (Exp
hist_H Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
local_mem_needed) Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp` VName -> Exp
Imp.vi32 VName
hist_L
  let max_S :: Exp
max_S = case KernelBody ExplicitMemory -> Passage
bodyPassage KernelBody ExplicitMemory
kbody of
                Passage
MustBeSinglePass -> Exp
1
                Passage
MayBeMultiPass -> Int -> Exp
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int -> Exp) -> Int -> Exp
forall a b. (a -> b) -> a -> b
$ [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Ord a) => t a -> a
maximum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Int) -> [SegHistSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> Int
slugMaxLocalMemPasses [SegHistSlug]
slugs

  -- We only use local memory if the number of updates per histogram
  -- at least matches the histogram size, as otherwise it is not
  -- asymptotically efficient.  This mostly matters for the segmented
  -- case.
  let pick_local :: Exp
pick_local =
        Exp
hist_Nin Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.>=. Exp
hist_H
        Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. (Exp
local_mem_needed Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<=. VName -> Exp
Imp.vi32 VName
hist_L)
        Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. (Exp
hist_S Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<=. Exp
max_S)
        Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. Exp
hist_C Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<=. Exp
hist_B
        Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. VName -> Exp
Imp.vi32 VName
hist_M Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.>. Exp
0

      groups_per_segment :: Count NumGroups Exp
groups_per_segment
        | Bool
segmented = Count NumGroups Exp
num_groups' Count NumGroups Exp -> Count NumGroups Exp -> Count NumGroups Exp
forall e. IntegralExp e => e -> e -> e
`quotRoundingUp` Exp -> Count NumGroups Exp
forall u e. e -> Count u e
Imp.Count Exp
hist_Nout
        | Bool
otherwise = Count NumGroups Exp
num_groups'

      run :: CallKernelGen ()
run = do
        Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"## Using local memory" Maybe Exp
forall a. Maybe a
Nothing
        Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Histogram size (H)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_H
        Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Multiplication degree (M)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ VName -> Exp
Imp.vi32 VName
hist_M
        Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Cooperation level (C)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_C
        Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of chunks (S)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_S
        Bool -> CallKernelGen () -> CallKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
segmented (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
          Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Groups per segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment
        VName
-> Count NumGroups Exp
-> [PatElem ExplicitMemory]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> Exp
-> [SegHistSlug]
-> KernelBody ExplicitMemory
-> CallKernelGen ()
histKernelLocal VName
hist_M Count NumGroups Exp
groups_per_segment [PatElem ExplicitMemory]
map_pes
          Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space Exp
hist_S [SegHistSlug]
slugs KernelBody ExplicitMemory
kbody

  (Exp, CallKernelGen ()) -> CallKernelGen (Exp, CallKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp
pick_local, CallKernelGen ()
run)

-- | Generate host-level code for a 'SegHist'.
--
-- Most of this function is not the histogram part itself, but rather
-- figuring out whether to use a local or global memory strategy, as
-- well as collapsing the subhistograms produced (which are always in
-- global memory, but their number may vary).
--
-- The steps, in order:
--
-- 1. Compute the memory usage of every histogram operation via
--    'computeHistoUsage', which also yields one 'SegHistSlug' per
--    operation.
--
-- 2. Unless the total usage is zero, emit debug prints for the
--    tuning parameters, ask 'localMemoryCase' for a run-time
--    predicate and a local-memory implementation, and fall back to
--    'histKernelGlobal' when the predicate is false.
--
-- 3. For every operation, either alias the result memory to the
--    single subhistogram (when there is exactly one), or combine the
--    subhistograms with a segmented reduction ('compileSegRed'').
compileSegHist :: Pattern ExplicitMemory
               -> Count NumGroups SubExp -> Count GroupSize SubExp
               -> SegSpace
               -> [HistOp ExplicitMemory]
               -> KernelBody ExplicitMemory
               -> CallKernelGen ()
compileSegHist (Pattern _ pes) num_groups group_size space ops kbody = do
  num_groups' <- traverse toExp num_groups
  group_size' <- traverse toExp group_size

  dims <- mapM toExp $ segSpaceDims space

  -- NOTE(review): 'num_red_res' counts one extra element per operation
  -- on top of the neutral elements, whereas 'pes_per_op' below chunks
  -- 'all_red_pes' by 'histDest' length only.  Confirm that this split
  -- matches the layout of the pattern when there are map results.
  let num_red_res = length ops + sum (map (length . histNeutral) ops)
      (all_red_pes, map_pes) = splitAt num_red_res pes
      -- The innermost dimension is the one the histogram is computed
      -- over; the outer ones ('segment_dims') are the segments.
      segment_size = last dims

  (op_hs, op_seg_hs, slugs) <- unzip3 <$> mapM (computeHistoUsage space) ops
  h <- dPrimVE "h" $ Imp.unCount $ sum op_hs
  seg_h <- dPrimVE "seg_h" $ Imp.unCount $ sum op_seg_hs

  -- Check for emptiness to avoid division-by-zero.
  sUnless (seg_h .==. 0) $ do

    -- Maximum group size (or actual, in this case).
    let hist_B = unCount group_size'

    -- Size of a histogram.
    hist_H <- dPrimVE "hist_H" $ sum $ map (toExp' int32 . histWidth) ops

    -- Size of a single histogram element.  Actually the weighted
    -- average of histogram elements in cases where we have more than
    -- one histogram operation, plus any locks.
    let lockSize slug = case slugAtomicUpdate slug of
                          AtomicLocking{} -> Just $ primByteSize int32
                          _               -> Nothing
    hist_el_size <- dPrimVE "hist_el_size" $
                    foldl' (+) (h `quotRoundingUp` hist_H) $
                    mapMaybe lockSize slugs

    -- Input elements contributing to each histogram.
    hist_N <- dPrimVE "hist_N" segment_size

    -- Compute RF as the average RF over all the histograms.
    hist_RF <- dPrimVE "hist_RF" $
               sum (map (toExp' int32 . histRaceFactor . slugOp) slugs)
               `quot`
               genericLength slugs

    -- Total number of threads.
    let hist_T = unCount num_groups' * unCount group_size'

    emit $ Imp.DebugPrint "\n# SegHist" Nothing
    emit $ Imp.DebugPrint "Number of threads (T)" $ Just hist_T
    emit $ Imp.DebugPrint "Desired group size (B)" $ Just hist_B
    emit $ Imp.DebugPrint "Histogram size (H)" $ Just hist_H
    emit $ Imp.DebugPrint "Input elements per histogram (N)" $ Just hist_N
    emit $ Imp.DebugPrint "Number of segments" $
      Just $ product $ map (toExp' int32 . snd) segment_dims
    emit $ Imp.DebugPrint "Histogram element size (el_size)" $ Just hist_el_size
    emit $ Imp.DebugPrint "Race factor (RF)" $ Just hist_RF
    emit $ Imp.DebugPrint "Memory per set of subhistograms per segment" $ Just h
    emit $ Imp.DebugPrint "Memory per set of subhistograms times segments" $ Just seg_h

    -- 'use_local_memory' is a run-time predicate; both strategies are
    -- generated and the choice between them is made by the emitted code.
    (use_local_memory, run_in_local_memory) <-
      localMemoryCase map_pes hist_T space hist_H hist_el_size hist_N hist_RF slugs kbody

    sIf use_local_memory run_in_local_memory $
      histKernelGlobal map_pes num_groups group_size space slugs kbody

    let pes_per_op = chunks (map (length . histDest) ops) all_red_pes

    forM_ (zip3 slugs pes_per_op ops) $ \(slug, red_pes, op) -> do
      let num_histos = slugNumSubhistos slug
          subhistos = map subhistosArray $ slugSubhistos slug

      let unitHistoCase =
            -- This is OK because the memory blocks are at least as
            -- large as the ones we are supposed to use for the result.
            forM_ (zip red_pes subhistos) $ \(pe, subhisto) -> do
              pe_mem <- memLocationName . entryArrayLocation <$>
                        lookupArray (patElemName pe)
              subhisto_mem <- memLocationName . entryArrayLocation <$>
                              lookupArray subhisto
              emit $ Imp.SetMem pe_mem subhisto_mem $ Space "device"

      sIf (Imp.var num_histos int32 .==. 1) unitHistoCase $ do
        -- For the segmented reduction, we keep the segment dimensions
        -- unchanged.  To this, we add two dimensions: one over the number
        -- of buckets, and one over the number of subhistograms.  This
        -- inner dimension is the one that is collapsed in the reduction.
        let num_buckets = histWidth op
            vector_dims = shapeDims $ histShape op

        bucket_id <- newVName "bucket_id"
        subhistogram_id <- newVName "subhistogram_id"
        vector_ids <- mapM (const $ newVName "vector_id") vector_dims

        flat_gtid <- newVName "flat_gtid"

        let lvl = SegThread num_groups group_size SegVirt
            segred_space =
              SegSpace flat_gtid $
              segment_dims ++
              [(bucket_id, num_buckets)] ++
              zip vector_ids vector_dims ++
              [(subhistogram_id, Var num_histos)]

        let segred_op = SegRedOp Commutative (histOp op) (histNeutral op) mempty
            -- Index used to read from each subhistogram array.  Note
            -- that the subhistogram index comes before the bucket
            -- index here, unlike in 'segred_space'.
            subhisto_index =
              map Imp.vi32 $
              map fst segment_dims ++ [subhistogram_id, bucket_id] ++ vector_ids

        compileSegRed' (Pattern [] red_pes) lvl segred_space [segred_op] $ \red_cont ->
          red_cont [ (Var subhisto, subhisto_index) | subhisto <- subhistos ]

  where segment_dims = init $ unSegSpace space