{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

-- | Our compilation strategy for 'SegHist' is based around avoiding
-- bin conflicts.  We do this by splitting the input into chunks, and
-- for each chunk computing a single subhistogram.  Then we combine
-- the subhistograms using an ordinary segmented reduction ('SegRed').
--
-- There are some branches around to efficiently handle the case where
-- we use only a single subhistogram (because it's large), so that we
-- respect the asymptotics, and do not copy the destination array.
--
-- We also use a heuristic strategy for computing subhistograms in
-- local memory when possible.  Given:
--
-- H: total size of histograms in bytes, including any lock arrays.
--
-- G: group size
--
-- T: number of bytes of local memory each thread can be given without
-- impacting occupancy (determined experimentally, e.g. 32).
--
-- LMAX: maximum amount of local memory per workgroup (hard limit).
--
-- We wish to compute:
--
-- COOP: cooperation level (number of threads per subhistogram)
--
-- LH: number of local memory subhistograms
--
-- We do this as:
--
-- COOP = ceil(H / T)
-- LH = ceil((G*T)/H)
-- if COOP <= G && H <= LMAX then
--   use local memory
-- else
--   use global memory
module Futhark.CodeGen.ImpGen.GPU.SegHist (compileSegHist) where

import Control.Monad.Except
import Data.List (foldl', genericLength, zip4, zip6)
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.GPU as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegRed (compileSegRed')
import Futhark.Construct (fullSliceNum)
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Pass.ExplicitAllocations ()
import Futhark.Util (chunks, mapAccumLM, maxinum, splitFromEnd, takeLast)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)

-- | Information about the subhistograms belonging to a single
-- destination array of a histogram operator.
data SubhistosInfo = SubhistosInfo
  { -- | Name of the array holding the subhistograms.
    subhistosArray :: VName,
    -- | Host-side action that makes the memory behind
    -- 'subhistosArray' available.  As constructed by
    -- 'computeHistoUsage', it either aliases the destination memory
    -- (single subhistogram) or allocates and initialises fresh memory.
    subhistosAlloc :: CallKernelGen ()
  }

-- | Everything the code generator tracks for one histogram operator
-- of a 'SegHist'.
data SegHistSlug = SegHistSlug
  { -- | The histogram operator itself.
    slugOp :: HistOp GPUMem,
    -- | Run-time variable holding the number of subhistograms in use.
    slugNumSubhistos :: TV Int64,
    -- | One entry per destination array of the operator.
    slugSubhistos :: [SubhistosInfo],
    -- | How bins are updated atomically, as chosen by
    -- 'atomicUpdateLocking' for the operator's lambda.
    slugAtomicUpdate :: AtomicUpdate GPUMem KernelEnv
  }

-- | The number of bytes occupied by a single histogram of the given
-- operator.  For every return type of the operator lambda, the type
-- is first extended with the operator's 'histShape' and then with an
-- outer dimension of 'histWidth'; the sizes of the resulting array
-- types are summed.
histoSpaceUsage ::
  HistOp GPUMem ->
  Imp.Count Imp.Bytes (Imp.TExp Int64)
histoSpaceUsage op =
  sum $
    map
      ( typeSize
          . (`arrayOfRow` histWidth op)
          . (`arrayOfShape` histShape op)
      )
      $ lambdaReturnType $ histOp op

-- | Figure out how much memory is needed per histogram, both
-- segmented and unsegmented, and compute some other auxiliary
-- information.
--
-- The result is @(h, segmented_h, slug)@, where @h@ is
-- 'histoSpaceUsage' of the operator, @segmented_h@ is @h@ multiplied
-- by the sizes of all but the last dimension of the 'SegSpace', and
-- @slug@ bundles the operator with its subhistogram arrays and its
-- atomic update strategy.
computeHistoUsage ::
  SegSpace ->
  HistOp GPUMem ->
  CallKernelGen
    ( Imp.Count Imp.Bytes (Imp.TExp Int64),
      Imp.Count Imp.Bytes (Imp.TExp Int64),
      SegHistSlug
    )
computeHistoUsage space op = do
  -- All dimensions of the space except the last one are segment
  -- dimensions.
  let segment_dims = init $ unSegSpace space
      num_segments = length segment_dims

  -- Create names for the intermediate array memory blocks,
  -- memory block sizes, arrays, and number of subhistograms.
  --
  -- The subhistogram count is a 'TV Int64' and is used directly as an
  -- array dimension (through 'tvSize'), so the underlying variable
  -- must be declared as a 64-bit integer.
  num_subhistos <- dPrim "num_subhistos" int64
  subhisto_infos <- forM (zip (histDest op) (histNeutral op)) $ \(dest, ne) -> do
    dest_t <- lookupType dest
    dest_mem <- entryArrayLoc <$> lookupArray dest

    subhistos_mem <-
      sDeclareMem (baseString dest ++ "_subhistos_mem") (Space "device")

    -- The subhistogram array has the shape of the destination, with an
    -- extra dimension of size num_subhistos inserted right after the
    -- segment dimensions.
    let subhistos_shape =
          Shape (map snd segment_dims ++ [tvSize num_subhistos])
            <> stripDims num_segments (arrayShape dest_t)
        subhistos_membind =
          ArrayIn subhistos_mem $
            IxFun.iota $
              map pe64 $ shapeDims subhistos_shape
    subhistos <-
      sArray
        (baseString dest ++ "_subhistos")
        (elemType dest_t)
        subhistos_shape
        subhistos_membind

    return $
      SubhistosInfo subhistos $ do
        -- A single subhistogram: let the subhistogram memory refer to
        -- the destination's memory instead of allocating and copying.
        let unitHistoCase =
              emit $
                Imp.SetMem subhistos_mem (memLocName dest_mem) $
                  Space "device"

            -- Several subhistograms: allocate room for all of them,
            -- fill everything with the neutral element, and then write
            -- the destination into subhistogram 0 of every segment.
            multiHistoCase = do
              let num_elems =
                    foldl' (*) (sExt64 $ tvExp num_subhistos) $
                      map toInt64Exp $ arrayDims dest_t

              let subhistos_mem_size =
                    Imp.bytes $
                      Imp.unCount (Imp.elements num_elems `Imp.withElemType` elemType dest_t)

              sAlloc_ subhistos_mem subhistos_mem_size $ Space "device"
              sReplicate subhistos ne
              subhistos_t <- lookupType subhistos
              let slice =
                    fullSliceNum (map toInt64Exp $ arrayDims subhistos_t) $
                      map (unitSlice 0 . toInt64Exp . snd) segment_dims
                        ++ [DimFix 0]
              sUpdate subhistos slice $ Var dest

        sIf (tvExp num_subhistos .==. 1) unitHistoCase multiHistoCase

  let h = histoSpaceUsage op
      segmented_h = h * product (map (Imp.bytes . toInt64Exp) $ init $ segSpaceDims space)

  atomics <- hostAtomics <$> askEnv

  return
    ( h,
      segmented_h,
      SegHistSlug op num_subhistos subhisto_infos $
        atomicUpdateLocking atomics $ histOp op
    )

-- | Construct the function that performs an atomic update of the
-- given destination arrays in global memory for one histogram
-- operator.
--
-- The first argument is a lock array set up for an earlier operator,
-- if any.  It is reused when the operator needs locking; if none is
-- supplied and locking is needed, a fresh one is created.  The
-- returned 'Maybe' 'Locking' is the locking that exists after this
-- call.
prepareAtomicUpdateGlobal ::
  Maybe Locking ->
  [VName] ->
  SegHistSlug ->
  CallKernelGen
    ( Maybe Locking,
      [Imp.TExp Int64] -> InKernelGen ()
    )
prepareAtomicUpdateGlobal l dests slug =
  -- We need a separate lock array if the operators are not all of a
  -- particularly simple form that permits pure atomic operations.
  case (l, slugAtomicUpdate slug) of
    (_, AtomicPrim f) -> return (l, f (Space "global") dests)
    (_, AtomicCAS f) -> return (l, f (Space "global") dests)
    (Just l', AtomicLocking f) -> return (l, f l' (Space "global") dests)
    (Nothing, AtomicLocking f) -> do
      -- The number of locks used here is too low, but since we are
      -- currently forced to inline a huge list, I'm keeping it down
      -- for now.  Some quick experiments suggested that it has little
      -- impact anyway (maybe the locking case is just too slow).
      --
      -- A fun solution would also be to use a simple hashing
      -- algorithm to ensure good distribution of locks.
      let num_locks = 100151
          dims =
            map toInt64Exp $
              shapeDims (histShape (slugOp slug))
                ++ [ tvSize (slugNumSubhistos slug),
                     histWidth (slugOp slug)
                   ]
      locks <-
        sStaticArray "hist_locks" (Space "device") int32 $
          Imp.ArrayZeros num_locks
      -- A bin's lock is found by flattening its index with respect to
      -- dims and reducing that modulo the number of locks.
      let l' = Locking locks 0 1 0 (pure . (`rem` fromIntegral num_locks) . flattenIndex dims)
      return (Just l', f l' (Space "global") dests)

-- | Some kernel bodies are not safe (or efficient) to execute
-- multiple times.
data Passage
  = -- | The body must only be executed once; 'bodyPassage' selects
    -- this when the body consumes something.
    MustBeSinglePass
  | -- | The body may be executed several times; 'bodyPassage' selects
    -- this when the body consumes nothing.
    MayBeMultiPass
  deriving (Eq, Ord)

-- | Classify a kernel body by whether it can be executed more than
-- once.  Alias analysis is run on the body (starting from an empty
-- alias table); a body that consumes nothing is 'MayBeMultiPass',
-- and any other body is 'MustBeSinglePass'.
bodyPassage :: KernelBody GPUMem -> Passage
bodyPassage kbody
  | mempty == consumedInKernelBody (aliasAnalyseKernelBody mempty kbody) =
    MayBeMultiPass
  | otherwise =
    MustBeSinglePass

prepareIntermediateArraysGlobal ::
  Passage ->
  Imp.TExp Int32 ->
  Imp.TExp Int64 ->
  [SegHistSlug] ->
  CallKernelGen
    ( Imp.TExp Int32,
      [[Imp.TExp Int64] -> InKernelGen ()]
    )
prepareIntermediateArraysGlobal :: Passage
-> TExp Int32
-> TExp Int64
-> [SegHistSlug]
-> CallKernelGen (TExp Int32, [[TExp Int64] -> InKernelGen ()])
prepareIntermediateArraysGlobal Passage
passage TExp Int32
hist_T TExp Int64
hist_N [SegHistSlug]
slugs = do
  -- The paper formulae assume there is only one histogram, but in our
  -- implementation there can be multiple that have been horizontally
  -- fused.  We do a bit of trickery with summings and averages to
  -- pretend there is really only one.  For the case of a single
  -- histogram, the actual calculations should be the same as in the
  -- paper.

  -- The sum of all Hs.
  TExp Int64
hist_H <- String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_H" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> TExp Int64) -> [SegHistSlug] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> (SegHistSlug -> SubExp) -> SegHistSlug -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> SubExp
forall rep. HistOp rep -> SubExp
histWidth (HistOp GPUMem -> SubExp)
-> (SegHistSlug -> HistOp GPUMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs

  TExp Double
hist_RF <-
    String -> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_RF" (TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double))
-> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall a b. (a -> b) -> a -> b
$
      [TExp Double] -> TExp Double
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((SegHistSlug -> TExp Double) -> [SegHistSlug] -> [TExp Double]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 (TExp Int64 -> TExp Double)
-> (SegHistSlug -> TExp Int64) -> SegHistSlug -> TExp Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> (SegHistSlug -> SubExp) -> SegHistSlug -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> SubExp
forall rep. HistOp rep -> SubExp
histRaceFactor (HistOp GPUMem -> SubExp)
-> (SegHistSlug -> HistOp GPUMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs)
        TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ [SegHistSlug] -> TExp Double
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs

  TExp Int32
hist_el_size <- String -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_el_size" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ [TExp Int32] -> TExp Int32
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([TExp Int32] -> TExp Int32) -> [TExp Int32] -> TExp Int32
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> TExp Int32) -> [SegHistSlug] -> [TExp Int32]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> TExp Int32
slugElAvgSize [SegHistSlug]
slugs

  TExp Double
hist_C_max <-
    String -> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_C_max" (TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double))
-> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall a b. (a -> b) -> a -> b
$
      TExp Double -> TExp Double -> TExp Double
forall v.
TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMin64 (TExp Int32 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T) (TExp Double -> TExp Double) -> TExp Double -> TExp Double
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 TExp Int64
hist_H TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Double
hist_k_ct_min

  TExp Int32
hist_M_min <-
    String -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_M_min" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
      TExp Int32 -> TExp Int32 -> TExp Int32
forall v. TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMax32 TExp Int32
1 (TExp Int32 -> TExp Int32) -> TExp Int32 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Double -> TExp Int64
forall t v. TPrimExp t v -> TPrimExp Int64 v
t64 (TExp Double -> TExp Int64) -> TExp Double -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_T TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Double
hist_C_max

  -- Querying L2 cache size is not reliable.  Instead we provide a
  -- tunable knob with a hopefully sane default.
  let hist_L2_def :: Int64
hist_L2_def = Int64
4 Int64 -> Int64 -> Int64
forall a. Num a => a -> a -> a
* Int64
1024 Int64 -> Int64 -> Int64
forall a. Num a => a -> a -> a
* Int64
1024
  TV Any
hist_L2 <- String -> PrimType -> ImpM GPUMem HostEnv HostOp (TV Any)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"L2_size" PrimType
int32
  Maybe Name
entry <- ImpM GPUMem HostEnv HostOp (Maybe Name)
forall rep r op. ImpM rep r op (Maybe Name)
askFunction
  -- Equivalent to F_L2*L2 in paper.
  HostOp -> CallKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
    VName -> Name -> SizeClass -> HostOp
Imp.GetSize
      (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
hist_L2)
      (Maybe Name -> Name -> Name
keyWithEntryPoint Maybe Name
entry (Name -> Name) -> Name -> Name
forall a b. (a -> b) -> a -> b
$ String -> Name
nameFromString (VName -> String
forall a. Pretty a => a -> String
pretty (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
hist_L2)))
      (SizeClass -> HostOp) -> SizeClass -> HostOp
forall a b. (a -> b) -> a -> b
$ Name -> Int64 -> SizeClass
Imp.SizeBespoke (String -> Name
nameFromString String
"L2_for_histogram") Int64
hist_L2_def

  let hist_L2_ln_sz :: TExp Double
hist_L2_ln_sz = TExp Double
16 TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TExp Double
4 -- L2 cache line size approximation
  TExp Double
hist_RACE_exp <-
    String -> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_RACE_exp" (TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double))
-> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall a b. (a -> b) -> a -> b
$
      TExp Double -> TExp Double -> TExp Double
forall v.
TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMax64 TExp Double
1 (TExp Double -> TExp Double) -> TExp Double -> TExp Double
forall a b. (a -> b) -> a -> b
$
        (TExp Double
hist_k_RF TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TExp Double
hist_RF)
          TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ (TExp Double
hist_L2_ln_sz TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Int32 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 TExp Int32
hist_el_size)

  TV Int32
hist_S <- String -> PrimType -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"hist_S" PrimType
int32

  -- For sparse histograms (H exceeds N) we only want a single chunk.
  TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
    (TExp Int64
hist_N TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
hist_H)
    (TV Int32
hist_S TV Int32 -> TExp Int32 -> CallKernelGen ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- (TExp Int32
1 :: Imp.TExp Int32))
    (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ TV Int32
hist_S
      TV Int32 -> TExp Int32 -> CallKernelGen ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- case Passage
passage of
        Passage
MayBeMultiPass ->
          TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
            (TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_M_min TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_el_size)
              TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Double -> TExp Int64
forall t v. TPrimExp t v -> TPrimExp Int64 v
t64 (TExp Double
hist_F_L2 TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TPrimExp Any ExpLeaf -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 (TV Any -> TPrimExp Any ExpLeaf
forall t. TV t -> TExp t
tvExp TV Any
hist_L2) TExp Double -> TExp Double -> TExp Double
forall a. Num a => a -> a -> a
* TExp Double
hist_RACE_exp)
        Passage
MustBeSinglePass ->
          TExp Int32
1

  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Race expansion factor (RACE^exp)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Double -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Double
hist_RACE_exp
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of chunks (S)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_S

  [[TExp Int64] -> InKernelGen ()]
histograms <-
    (Maybe Locking, [[TExp Int64] -> InKernelGen ()])
-> [[TExp Int64] -> InKernelGen ()]
forall a b. (a, b) -> b
snd
      ((Maybe Locking, [[TExp Int64] -> InKernelGen ()])
 -> [[TExp Int64] -> InKernelGen ()])
-> ImpM
     GPUMem
     HostEnv
     HostOp
     (Maybe Locking, [[TExp Int64] -> InKernelGen ()])
-> ImpM GPUMem HostEnv HostOp [[TExp Int64] -> InKernelGen ()]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Maybe Locking
 -> SegHistSlug
 -> CallKernelGen (Maybe Locking, [TExp Int64] -> InKernelGen ()))
-> Maybe Locking
-> [SegHistSlug]
-> ImpM
     GPUMem
     HostEnv
     HostOp
     (Maybe Locking, [[TExp Int64] -> InKernelGen ()])
forall (m :: * -> *) acc x y.
Monad m =>
(acc -> x -> m (acc, y)) -> acc -> [x] -> m (acc, [y])
mapAccumLM
        (TPrimExp Any ExpLeaf
-> TExp Int32
-> TExp Int32
-> TExp Double
-> Maybe Locking
-> SegHistSlug
-> CallKernelGen (Maybe Locking, [TExp Int64] -> InKernelGen ())
onOp (TV Any -> TPrimExp Any ExpLeaf
forall t. TV t -> TExp t
tvExp TV Any
hist_L2) TExp Int32
hist_M_min (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_S) TExp Double
hist_RACE_exp)
        Maybe Locking
forall a. Maybe a
Nothing
        [SegHistSlug]
slugs

  (TExp Int32, [[TExp Int64] -> InKernelGen ()])
-> CallKernelGen (TExp Int32, [[TExp Int64] -> InKernelGen ()])
forall (m :: * -> *) a. Monad m => a -> m a
return (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_S, [[TExp Int64] -> InKernelGen ()]
histograms)
  where
    -- Tuning constants for the global-memory histogram heuristic.  All
    -- three were chosen experimentally.
    hist_k_ct_min :: Imp.TExp Double
    hist_k_ct_min = 2
    -- Multiplies hist_RF in the definition of hist_RACE_exp.
    hist_k_RF :: Imp.TExp Double
    hist_k_RF = 0.75
    -- Multiplies the L2 size, both when choosing the number of chunks
    -- and when computing hist_k_max.
    hist_F_L2 :: Imp.TExp Double
    hist_F_L2 = 0.4
    -- Reinterpret an integer expression as a double-precision one by
    -- wrapping it in a signed-integer-to-float conversion.
    --
    -- NOTE(review): the conversion is always tagged as Int32 -> Float64,
    -- whatever the phantom type @t@ is; this function is also applied to
    -- Int64-typed expressions (e.g. hist_N) - confirm that is intended.
    r64 :: TPrimExp t v -> TPrimExp Double v
    r64 e = isF64 $ ConvOpExp (SIToFP Int32 Float64) $ untyped e
    -- Convert a double-precision expression to a 64-bit signed integer
    -- expression by wrapping it in a Float64 -> Int64 conversion.
    t64 :: TPrimExp t v -> TPrimExp Int64 v
    t64 e = isInt64 $ ConvOpExp (FPToSI Float64 Int64) $ untyped e

    -- "Average element size": the total from 'slugElSize' divided by the
    -- number of components it was summed over.  That is the number of
    -- operator results, plus one when the update uses locking.
    slugElAvgSize :: SegHistSlug -> Imp.TExp Int32
    slugElAvgSize slug =
      slugElSize slug `quot` num_components
      where
        num_results = genericLength $ lambdaReturnType $ histOp $ slugOp slug
        num_components =
          case slugAtomicUpdate slug of
            AtomicLocking {} -> 1 + num_results
            _ -> num_results

    -- Total element size in bytes: the sum, over the operator's result
    -- types, of the size of that type as an array of shape
    -- @histShape op@.  When the update uses locking, one extra int32 of
    -- the same shape is included in the sum.  The result is narrowed to
    -- 32 bits.
    slugElSize :: SegHistSlug -> Imp.TExp Int32
    slugElSize (SegHistSlug op _ _ do_op) =
      sExt32 . unCount . sum $ map bytesOf counted_types
      where
        bytesOf = typeSize . (`arrayOfShape` histShape op)
        result_types = lambdaReturnType $ histOp op
        counted_types =
          case do_op of
            AtomicLocking {} -> Prim int32 : result_types
            _ -> result_types

    -- Set up the global-memory subhistograms for a single slug.
    --
    -- Computes the chunk size (H_chk), k_max, the cooperation level (C)
    -- and the multiplication degree (M), emitting a debug print for
    -- each, and stores M in the slug's num_subhistos variable.  Every
    -- destination array is then either aliased to its subhistogram
    -- memory (when M is 1) or given freshly allocated subhistograms.
    -- Finally the atomic update is prepared for the resulting arrays;
    -- the returned locking state is what the caller threads on to the
    -- next slug.
    onOp ::
      Imp.TExp t ->
      Imp.TExp Int32 ->
      Imp.TExp Int32 ->
      Imp.TExp Double ->
      Maybe Locking ->
      SegHistSlug ->
      CallKernelGen (Maybe Locking, [Imp.TExp Int64] -> InKernelGen ())
    onOp hist_L2 hist_M_min hist_S hist_RACE_exp l slug@(SegHistSlug op num_subhistos subhisto_info do_op) = do
      let debug :: String -> Imp.TExp a -> CallKernelGen ()
          debug msg e = emit $ Imp.DebugPrint msg $ Just $ untyped e

          uses_prim =
            case do_op of
              AtomicPrim {} -> True
              _ -> False

      chunk_width <-
        dPrimVE "hist_H_chk" $
          toInt64Exp (histWidth op) `divUp` sExt64 hist_S

      debug "Chunk size (H_chk)" chunk_width

      let l2_elems = r64 hist_L2 / r64 (slugElSize slug)
      k_max <-
        dPrimVE "hist_k_max" $
          fMin64 (hist_F_L2 * l2_elems * hist_RACE_exp) (r64 hist_N)
            / r64 hist_T

      updates <- dPrimVE "hist_u" $ if uses_prim then 2 else 1

      coop <-
        dPrimVE "hist_C" $
          fMin64 (r64 hist_T) (r64 (updates * chunk_width) / k_max)

      -- Number of subhistograms per result histogram.
      mult <-
        dPrimVE "hist_M" $
          if uses_prim
            then 1
            else sMax32 hist_M_min (sExt32 (t64 (r64 hist_T / coop)))

      debug "Elements/thread in L2 cache (k_max)" k_max
      debug "Multiplication degree (M)" mult
      debug "Cooperation level (C)" coop

      -- num_subhistos is the variable we use to communicate back.
      num_subhistos <-- sExt64 mult

      -- Initialise sub-histograms.
      --
      -- If the multiplication degree is 1, then we just reuse the
      -- original destination.  The idea is to avoid a copy if we are
      -- writing a small number of values into a very large prior
      -- histogram.
      let initDest (dest, info) = do
            dest_mem <- memLocName . entryArrayLoc <$> lookupArray dest
            sub_mem <-
              memLocName . entryArrayLoc <$> lookupArray (subhistosArray info)
            sIf
              (mult .==. 1)
              (emit $ Imp.SetMem sub_mem dest_mem $ Space "device")
              (subhistosAlloc info)
            pure $ subhistosArray info

      dests <- mapM initDest $ zip (histDest op) subhisto_info

      prepareAtomicUpdateGlobal l dests slug

histKernelGlobalPass ::
  [PatElem GPUMem] ->
  Count NumGroups (Imp.TExp Int64) ->
  Count GroupSize (Imp.TExp Int64) ->
  SegSpace ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  [[Imp.TExp Int64] -> InKernelGen ()] ->
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  CallKernelGen ()
histKernelGlobalPass :: [PatElem GPUMem]
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> SegSpace
-> [SegHistSlug]
-> KernelBody GPUMem
-> [[TExp Int64] -> InKernelGen ()]
-> TExp Int32
-> TExp Int32
-> CallKernelGen ()
histKernelGlobalPass [PatElem GPUMem]
map_pes Count NumGroups (TExp Int64)
num_groups Count GroupSize (TExp Int64)
group_size SegSpace
space [SegHistSlug]
slugs KernelBody GPUMem
kbody [[TExp Int64] -> InKernelGen ()]
histograms TExp Int32
hist_S TExp Int32
chk_i = do
  let ([VName]
space_is, [SubExp]
space_sizes) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      space_sizes_64 :: [TExp Int64]
space_sizes_64 = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int64 -> TExp Int64)
-> (SubExp -> TExp Int64) -> SubExp -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp) [SubExp]
space_sizes
      total_w_64 :: TExp Int64
total_w_64 = [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
space_sizes_64

  [TExp Int64]
hist_H_chks <- [SubExp]
-> (SubExp -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> ImpM GPUMem HostEnv HostOp [TExp Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ((SegHistSlug -> SubExp) -> [SegHistSlug] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (HistOp GPUMem -> SubExp
forall rep. HistOp rep -> SubExp
histWidth (HistOp GPUMem -> SubExp)
-> (SegHistSlug -> HistOp GPUMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs) ((SubExp -> ImpM GPUMem HostEnv HostOp (TExp Int64))
 -> ImpM GPUMem HostEnv HostOp [TExp Int64])
-> (SubExp -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> ImpM GPUMem HostEnv HostOp [TExp Int64]
forall a b. (a -> b) -> a -> b
$ \SubExp
w ->
    String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_H_chk" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
w TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_S

  String
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"seghist_global" Count NumGroups (TExp Int64)
num_groups Count GroupSize (TExp Int64)
group_size (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv

    -- Compute subhistogram index for each thread, per histogram.
    [TExp Int32]
subhisto_inds <- [SegHistSlug]
-> (SegHistSlug -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> ImpM GPUMem KernelEnv KernelOp [TExp Int32]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [SegHistSlug]
slugs ((SegHistSlug -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
 -> ImpM GPUMem KernelEnv KernelOp [TExp Int32])
-> (SegHistSlug -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> ImpM GPUMem KernelEnv KernelOp [TExp Int32]
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug ->
      String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"subhisto_ind" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
        KernelConstants -> TExp Int32
kernelGlobalThreadId KernelConstants
constants
          TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` ( KernelConstants -> TExp Int32
kernelNumThreads KernelConstants
constants
                     TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp (SegHistSlug -> TV Int64
slugNumSubhistos SegHistSlug
slug))
                 )

    -- Loop over flat offsets into the input and output.  The
    -- calculation is done with 64-bit integers to avoid overflow,
    -- but the final unflattened segment indexes are 32 bit.
    let gtid :: TExp Int64
gtid = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelGlobalThreadId KernelConstants
constants
        num_threads :: TExp Int64
num_threads = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelNumThreads KernelConstants
constants
    TExp Int64
-> TExp Int64
-> TExp Int64
-> (TExp Int64 -> InKernelGen ())
-> InKernelGen ()
forall t.
IntExp t =>
TExp t
-> TExp t -> TExp t -> (TExp t -> InKernelGen ()) -> InKernelGen ()
kernelLoop TExp Int64
gtid TExp Int64
num_threads TExp Int64
total_w_64 ((TExp Int64 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int64 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
offset -> do
      -- Construct segment indices.
      (VName -> TExp Int32 -> InKernelGen ())
-> [VName] -> [TExp Int32] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int32 -> InKernelGen ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
space_is ([TExp Int32] -> InKernelGen ()) -> [TExp Int32] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        (TExp Int64 -> TExp Int32) -> [TExp Int64] -> [TExp Int32]
forall a b. (a -> b) -> [a] -> [b]
map TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 ([TExp Int64] -> [TExp Int32]) -> [TExp Int64] -> [TExp Int32]
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
space_sizes_64 TExp Int64
offset

      -- We execute the bucket function once and update each histogram serially.
      -- We apply the bucket function only if the flat offset is less
      -- than the total input size (total_w_64).  This also involves
      -- writing to the mapout arrays.
      let input_in_bounds :: TExp Bool
input_in_bounds = TExp Int64
offset TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
total_w_64

      TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
input_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
          let ([KernelResult]
red_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElemT LetDecMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem GPUMem]
[PatElemT LetDecMem]
map_pes) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody

          String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"save map-out results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            [(PatElemT LetDecMem, KernelResult)]
-> ((PatElemT LetDecMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem]
-> [KernelResult] -> [(PatElemT LetDecMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem GPUMem]
[PatElemT LetDecMem]
map_pes [KernelResult]
map_res) (((PatElemT LetDecMem, KernelResult) -> InKernelGen ())
 -> InKernelGen ())
-> ((PatElemT LetDecMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
pe, KernelResult
res) ->
              VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
                (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
pe)
                (((VName, SubExp) -> TExp Int64)
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> TExp Int64
Imp.vi64 (VName -> TExp Int64)
-> ((VName, SubExp) -> VName) -> (VName, SubExp) -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst) ([(VName, SubExp)] -> [TExp Int64])
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space)
                (KernelResult -> SubExp
kernelResultSubExp KernelResult
res)
                []

          let ([KernelResult]
buckets, [KernelResult]
vs) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegHistSlug] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SegHistSlug]
slugs) [KernelResult]
red_res
              perOp :: [KernelResult] -> [[KernelResult]]
perOp = [Int] -> [KernelResult] -> [[KernelResult]]
forall a. [Int] -> [a] -> [[a]]
chunks ([Int] -> [KernelResult] -> [[KernelResult]])
-> [Int] -> [KernelResult] -> [[KernelResult]]
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Int) -> [SegHistSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int) -> (SegHistSlug -> [VName]) -> SegHistSlug -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> [VName]
forall rep. HistOp rep -> [VName]
histDest (HistOp GPUMem -> [VName])
-> (SegHistSlug -> HistOp GPUMem) -> SegHistSlug -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs

          String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            [(HistOp GPUMem, [TExp Int64] -> InKernelGen (), KernelResult,
  [KernelResult], TExp Int32, TExp Int64)]
-> ((HistOp GPUMem, [TExp Int64] -> InKernelGen (), KernelResult,
     [KernelResult], TExp Int32, TExp Int64)
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp GPUMem]
-> [[TExp Int64] -> InKernelGen ()]
-> [KernelResult]
-> [[KernelResult]]
-> [TExp Int32]
-> [TExp Int64]
-> [(HistOp GPUMem, [TExp Int64] -> InKernelGen (), KernelResult,
     [KernelResult], TExp Int32, TExp Int64)]
forall a b c d e f.
[a] -> [b] -> [c] -> [d] -> [e] -> [f] -> [(a, b, c, d, e, f)]
zip6 ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) [[TExp Int64] -> InKernelGen ()]
histograms [KernelResult]
buckets ([KernelResult] -> [[KernelResult]]
perOp [KernelResult]
vs) [TExp Int32]
subhisto_inds [TExp Int64]
hist_H_chks) (((HistOp GPUMem, [TExp Int64] -> InKernelGen (), KernelResult,
   [KernelResult], TExp Int32, TExp Int64)
  -> InKernelGen ())
 -> InKernelGen ())
-> ((HistOp GPUMem, [TExp Int64] -> InKernelGen (), KernelResult,
     [KernelResult], TExp Int32, TExp Int64)
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              \( HistOp SubExp
dest_w SubExp
_ [VName]
_ [SubExp]
_ Shape
shape LambdaT GPUMem
lam,
                 [TExp Int64] -> InKernelGen ()
do_op,
                 KernelResult
bucket,
                 [KernelResult]
vs',
                 TExp Int32
subhisto_ind,
                 TExp Int64
hist_H_chk
                 ) -> do
                  let chk_beg :: TExp Int64
chk_beg = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H_chk
                      bucket' :: TExp Int64
bucket' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64) -> SubExp -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelResult -> SubExp
kernelResultSubExp KernelResult
bucket
                      dest_w' :: TExp Int64
dest_w' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
dest_w
                      bucket_in_bounds :: TExp Bool
bucket_in_bounds =
                        TExp Int64
chk_beg TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
bucket'
                          TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
bucket' TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. (TExp Int64
chk_beg TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
hist_H_chk)
                          TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
bucket' TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
dest_w'
                      vs_params :: [Param LetDecMem]
vs_params = Int -> [Param LetDecMem] -> [Param LetDecMem]
forall a. Int -> [a] -> [a]
takeLast ([KernelResult] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
vs') ([Param LetDecMem] -> [Param LetDecMem])
-> [Param LetDecMem] -> [Param LetDecMem]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams LambdaT GPUMem
lam

                  TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
bucket_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
                    let bucket_is :: [TExp Int64]
bucket_is =
                          (VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 ([VName] -> [VName]
forall a. [a] -> [a]
init [VName]
space_is)
                            [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
subhisto_ind, TExp Int64
bucket']
                    [LParam GPUMem] -> InKernelGen ()
forall rep inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam GPUMem] -> InKernelGen ())
-> [LParam GPUMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams LambdaT GPUMem
lam
                    Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
shape (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is -> do
                      [(Param LetDecMem, KernelResult)]
-> ((Param LetDecMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [KernelResult] -> [(Param LetDecMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
vs_params [KernelResult]
vs') (((Param LetDecMem, KernelResult) -> InKernelGen ())
 -> InKernelGen ())
-> ((Param LetDecMem, KernelResult) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, KernelResult
res) ->
                        VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] (KernelResult -> SubExp
kernelResultSubExp KernelResult
res) [TExp Int64]
is
                      [TExp Int64] -> InKernelGen ()
do_op ([TExp Int64]
bucket_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is)

-- | Generate host code for the global-memory histogram strategy.
--
-- The steps are:
--
-- 1. Convert the group count and group size to 64-bit expressions and
--    derive the total thread count (narrowed to 32 bits).
--
-- 2. Emit a debug print announcing that global memory is being used.
--
-- 3. Call 'prepareIntermediateArraysGlobal', which yields @hist_S@
--    (used below as the trip count of the chunk loop) and one update
--    function per histogram operator.
--
-- 4. Emit a host-side loop over @chk_i@ from 0 up to @hist_S@, invoking
--    'histKernelGlobalPass' once per iteration.
histKernelGlobal ::
  [PatElem GPUMem] ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  CallKernelGen ()
histKernelGlobal map_pes num_groups group_size space slugs kbody = do
  let num_groups' = fmap toInt64Exp num_groups
      group_size' = fmap toInt64Exp group_size
  let (_space_is, space_sizes) = unzip $ unSegSpace space
      -- The product is computed in 64 bits and then narrowed to 32
      -- bits, as that is what 'prepareIntermediateArraysGlobal' takes.
      num_threads = sExt32 $ unCount num_groups' * unCount group_size'

  emit $ Imp.DebugPrint "## Using global memory" Nothing

  -- NOTE(review): 'last' is partial; this relies on the segment space
  -- having at least one dimension.  The innermost dimension size is
  -- what gets passed along here.
  (hist_S, histograms) <-
    prepareIntermediateArraysGlobal
      (bodyPassage kbody)
      num_threads
      (toInt64Exp $ last space_sizes)
      slugs

  sFor "chk_i" hist_S $ \chk_i ->
    histKernelGlobalPass
      map_pes
      num_groups'
      group_size'
      space
      slugs
      kbody
      histograms
      hist_S
      chk_i

-- | One entry per histogram operator.  Each entry pairs a list of
-- array names with a kernel-side action that, given a 'SubExp',
-- produces another list of array names together with a function that
-- takes a list of 64-bit index expressions and emits kernel code.
type InitLocalHistograms =
  [ ( [VName],
      SubExp -> InKernelGen ([VName], [Imp.TExp Int64] -> InKernelGen ())
    )
  ]

-- | Host-side preparation for the local-memory histogram strategy.
--
-- Computes @num_segments@ as the product of all but the last dimension
-- of the segment space, then processes every 'SegHistSlug' with
-- @onOp@.  For each operator the result pairs the names of its
-- global-memory subhistogram arrays with a kernel-side initialiser.
-- Given a chunk size @hist_H_chk@, the initialiser allocates the
-- arrays in @Space "local"@ and returns them along with the update
-- function operating on them.
prepareIntermediateArraysLocal ::
  TV Int32 ->
  Count NumGroups (Imp.TExp Int64) ->
  SegSpace ->
  [SegHistSlug] ->
  CallKernelGen InitLocalHistograms
prepareIntermediateArraysLocal num_subhistos_per_group groups_per_segment space slugs = do
  -- NOTE(review): 'init' is partial; this relies on the segment space
  -- being non-empty.
  num_segments <-
    dPrimVE "num_segments" $
      product $ map (toInt64Exp . snd) $ init $ unSegSpace space
  mapM (onOp num_segments) slugs
  where
    -- Handle a single histogram operator.
    onOp num_segments (SegHistSlug op num_subhistos subhisto_info do_op) = do
      -- Set the subhistogram count to groups-per-segment times the
      -- number of segments.
      num_subhistos <-- sExt64 (unCount groups_per_segment) * num_segments

      emit $
        Imp.DebugPrint "Number of subhistograms in global memory" $
          Just $ untyped $ tvExp num_subhistos

      -- Build a function from the chunk size to the concrete update
      -- operation.  Only the locking variant needs the chunk size,
      -- because it has to allocate a lock array of matching shape.
      mk_op <-
        case do_op of
          AtomicPrim f -> return $ const $ return f
          AtomicCAS f -> return $ const $ return f
          AtomicLocking f -> return $ \hist_H_chk -> do
            -- Lock array shape: subhistograms per group, followed by
            -- the operator's histogram shape, followed by the chunk
            -- size.
            let lock_shape =
                  Shape $
                    tvSize num_subhistos_per_group :
                    shapeDims (histShape op) ++ [hist_H_chk]

            let dims = map toInt64Exp $ shapeDims lock_shape

            locks <- sAllocArray "locks" int32 lock_shape $ Space "local"

            sComment "All locks start out unlocked" $
              groupCoverSpace dims $ \is ->
                copyDWIMFix locks is (intConst Int32 0) []

            return $ f $ Locking locks 0 1 0 id

      -- Initialise local-memory sub-histograms.  Each array gets an
      -- extra leading dimension of size num_subhistos_per_group, and
      -- the outer dimension of the histogram type is replaced by the
      -- chunk size.
      let init_local_subhistos hist_H_chk = do
            local_subhistos <-
              forM (histType op) $ \t -> do
                let sub_local_shape =
                      Shape [tvSize num_subhistos_per_group]
                        <> (arrayShape t `setOuterDim` hist_H_chk)
                sAllocArray
                  "subhistogram_local"
                  (elemType t)
                  sub_local_shape
                  (Space "local")

            do_op' <- mk_op hist_H_chk

            return (local_subhistos, do_op' (Space "local") local_subhistos)

      -- Initialise global-memory sub-histograms.
      glob_subhistos <- forM subhisto_info $ \info -> do
        subhistosAlloc info
        return $ subhistosArray info

      return (glob_subhistos, init_local_subhistos)

histKernelLocalPass ::
  TV Int32 ->
  Count NumGroups (Imp.TExp Int64) ->
  [PatElem GPUMem] ->
  Count NumGroups (Imp.TExp Int64) ->
  Count GroupSize (Imp.TExp Int64) ->
  SegSpace ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  InitLocalHistograms ->
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  CallKernelGen ()
histKernelLocalPass :: TV Int32
-> Count NumGroups (TExp Int64)
-> [PatElem GPUMem]
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> SegSpace
-> [SegHistSlug]
-> KernelBody GPUMem
-> InitLocalHistograms
-> TExp Int32
-> TExp Int32
-> CallKernelGen ()
histKernelLocalPass
  TV Int32
num_subhistos_per_group_var
  Count NumGroups (TExp Int64)
groups_per_segment
  [PatElem GPUMem]
map_pes
  Count NumGroups (TExp Int64)
num_groups
  Count GroupSize (TExp Int64)
group_size
  SegSpace
space
  [SegHistSlug]
slugs
  KernelBody GPUMem
kbody
  InitLocalHistograms
init_histograms
  TExp Int32
hist_S
  TExp Int32
chk_i = do
    let ([VName]
space_is, [SubExp]
space_sizes) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
        segment_is :: [VName]
segment_is = [VName] -> [VName]
forall a. [a] -> [a]
init [VName]
space_is
        segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. [a] -> [a]
init [SubExp]
space_sizes
        (VName
i_in_segment, SubExp
segment_size) = [(VName, SubExp)] -> (VName, SubExp)
forall a. [a] -> a
last ([(VName, SubExp)] -> (VName, SubExp))
-> [(VName, SubExp)] -> (VName, SubExp)
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
        num_subhistos_per_group :: TExp Int32
num_subhistos_per_group = TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
num_subhistos_per_group_var
        segment_size' :: TExp Int64
segment_size' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
segment_size

    TExp Int64
num_segments <-
      String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"num_segments" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
        [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
segment_dims

    -- One variable per histogram operator in 'slugs': the operator's
    -- 'histWidth' divided by 'hist_S', rounding up ('divUp').
    -- NOTE(review): 'hist_S' is defined outside this excerpt; it looks
    -- like the number of chunks each histogram is split into, which
    -- would make this the width of one chunk -- confirm against its
    -- definition.
    [TV Int64]
hist_H_chks <- [SubExp]
-> (SubExp -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> ImpM GPUMem HostEnv HostOp [TV Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ((SegHistSlug -> SubExp) -> [SegHistSlug] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (HistOp GPUMem -> SubExp
forall rep. HistOp rep -> SubExp
histWidth (HistOp GPUMem -> SubExp)
-> (SegHistSlug -> HistOp GPUMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs) ((SubExp -> ImpM GPUMem HostEnv HostOp (TV Int64))
 -> ImpM GPUMem HostEnv HostOp [TV Int64])
-> (SubExp -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> ImpM GPUMem HostEnv HostOp [TV Int64]
forall a b. (a -> b) -> a -> b
$ \SubExp
w ->
      String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"hist_H_chk" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
w TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_S

    [([TExp Int64], TExp Int64, TExp Int32)]
histo_sizes <- [(SegHistSlug, TV Int64)]
-> ((SegHistSlug, TV Int64)
    -> ImpM
         GPUMem HostEnv HostOp ([TExp Int64], TExp Int64, TExp Int32))
-> ImpM
     GPUMem HostEnv HostOp [([TExp Int64], TExp Int64, TExp Int32)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SegHistSlug] -> [TV Int64] -> [(SegHistSlug, TV Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegHistSlug]
slugs [TV Int64]
hist_H_chks) (((SegHistSlug, TV Int64)
  -> ImpM
       GPUMem HostEnv HostOp ([TExp Int64], TExp Int64, TExp Int32))
 -> ImpM
      GPUMem HostEnv HostOp [([TExp Int64], TExp Int64, TExp Int32)])
-> ((SegHistSlug, TV Int64)
    -> ImpM
         GPUMem HostEnv HostOp ([TExp Int64], TExp Int64, TExp Int32))
-> ImpM
     GPUMem HostEnv HostOp [([TExp Int64], TExp Int64, TExp Int32)]
forall a b. (a -> b) -> a -> b
$ \(SegHistSlug
slug, TV Int64
hist_H_chk) -> do
      let histo_dims :: [TExp Int64]
histo_dims =
            TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
hist_H_chk TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
:
            (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)))
      TExp Int64
histo_size <-
        String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"histo_size" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
histo_dims
      let group_hists_size :: TExp Int64
group_hists_size =
            TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
num_subhistos_per_group TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
histo_size
      TExp Int32
init_per_thread <-
        String -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"init_per_thread" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Int64
group_hists_size TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size
      ([TExp Int64], TExp Int64, TExp Int32)
-> ImpM
     GPUMem HostEnv HostOp ([TExp Int64], TExp Int64, TExp Int32)
forall (m :: * -> *) a. Monad m => a -> m a
return ([TExp Int64]
histo_dims, TExp Int64
histo_size, TExp Int32
init_per_thread)

    String
-> Count NumGroups (TExp Int64)
-> Count GroupSize (TExp Int64)
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"seghist_local" Count NumGroups (TExp Int64)
num_groups Count GroupSize (TExp Int64)
group_size (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
      SegVirt
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt (TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
num_segments) ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
group_id -> do
        KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv

        -- Flat index of the segment this group works on: the group id
        -- handed to us by 'virtualiseGroups', divided ('quot') by the
        -- number of groups per segment (narrowed to 32 bits).
        TExp Int32
flat_segment_id <- String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"flat_segment_id" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
group_id TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment)
        -- Position of this group among the groups assigned to the same
        -- segment: the group id modulo ('rem') the number of groups per
        -- segment (narrowed to 32 bits).
        TExp Int32
gid_in_segment <- String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"gid_in_segment" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
group_id TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment)
        -- This pgtid is kind of a "virtualised physical" gtid - not the
        -- same thing as the gtid used for the SegHist itself.
        TExp Int32
pgtid_in_segment <-
          String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"pgtid_in_segment" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
            TExp Int32
gid_in_segment TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
              TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
        TExp Int32
threads_per_segment <-
          String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"threads_per_segment" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
            TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants

        -- Set segment indices.
        (VName -> TExp Int64 -> InKernelGen ())
-> [VName] -> [TExp Int64] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> InKernelGen ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
segment_is ([TExp Int64] -> InKernelGen ()) -> [TExp Int64] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ((SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
segment_dims) (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
flat_segment_id

        [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
histograms <- [(([VName],
   SubExp
   -> ImpM
        GPUMem
        KernelEnv
        KernelOp
        ([VName], [TExp Int64] -> InKernelGen ())),
  TV Int64)]
-> ((([VName],
      SubExp
      -> ImpM
           GPUMem
           KernelEnv
           KernelOp
           ([VName], [TExp Int64] -> InKernelGen ())),
     TV Int64)
    -> ImpM
         GPUMem
         KernelEnv
         KernelOp
         ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()))
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (InitLocalHistograms
-> [TV Int64]
-> [(([VName],
      SubExp
      -> ImpM
           GPUMem
           KernelEnv
           KernelOp
           ([VName], [TExp Int64] -> InKernelGen ())),
     TV Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip InitLocalHistograms
init_histograms [TV Int64]
hist_H_chks) (((([VName],
    SubExp
    -> ImpM
         GPUMem
         KernelEnv
         KernelOp
         ([VName], [TExp Int64] -> InKernelGen ())),
   TV Int64)
  -> ImpM
       GPUMem
       KernelEnv
       KernelOp
       ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()))
 -> ImpM
      GPUMem
      KernelEnv
      KernelOp
      [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())])
-> ((([VName],
      SubExp
      -> ImpM
           GPUMem
           KernelEnv
           KernelOp
           ([VName], [TExp Int64] -> InKernelGen ())),
     TV Int64)
    -> ImpM
         GPUMem
         KernelEnv
         KernelOp
         ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()))
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
forall a b. (a -> b) -> a -> b
$
          \(([VName]
glob_subhistos, SubExp
-> ImpM
     GPUMem KernelEnv KernelOp ([VName], [TExp Int64] -> InKernelGen ())
init_local_subhistos), TV Int64
hist_H_chk) -> do
            ([VName]
local_subhistos, [TExp Int64] -> InKernelGen ()
do_op) <- SubExp
-> ImpM
     GPUMem KernelEnv KernelOp ([VName], [TExp Int64] -> InKernelGen ())
init_local_subhistos (SubExp
 -> ImpM
      GPUMem
      KernelEnv
      KernelOp
      ([VName], [TExp Int64] -> InKernelGen ()))
-> SubExp
-> ImpM
     GPUMem KernelEnv KernelOp ([VName], [TExp Int64] -> InKernelGen ())
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
hist_H_chk
            ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())
-> ImpM
     GPUMem
     KernelEnv
     KernelOp
     ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
glob_subhistos [VName]
local_subhistos, TV Int64
hist_H_chk, [TExp Int64] -> InKernelGen ()
do_op)

        -- Find index of local subhistograms updated by this thread.  We
        -- try to ensure, as much as possible, that threads in the same
        -- warp use different subhistograms, to avoid conflicts.
        TExp Int32
thread_local_subhisto_i <-
          String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"thread_local_subhisto_i" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
            KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int32
num_subhistos_per_group

        -- Run the given action once per histogram operator.  The lists
        -- 'slugs', 'histograms' and 'histo_sizes' are zipped together,
        -- and for each element the action receives: the slug, the
        -- (global, local) destination pairs, the 'hist_H_chk' variable
        -- as an expression, and the dimensions, size and
        -- 'init_per_thread' count taken from 'histo_sizes'.  The third
        -- component of each 'histograms' entry (the update function) is
        -- deliberately ignored here.
        let onSlugs :: (SegHistSlug
 -> [(VName, VName)]
 -> TExp Int64
 -> [TExp Int64]
 -> TExp Int64
 -> TExp Int32
 -> InKernelGen ())
-> InKernelGen ()
onSlugs SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ()
f =
              [(SegHistSlug,
  ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
  ([TExp Int64], TExp Int64, TExp Int32))]
-> ((SegHistSlug,
     ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
     ([TExp Int64], TExp Int64, TExp Int32))
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegHistSlug]
-> [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
-> [([TExp Int64], TExp Int64, TExp Int32)]
-> [(SegHistSlug,
     ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
     ([TExp Int64], TExp Int64, TExp Int32))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegHistSlug]
slugs [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
histograms [([TExp Int64], TExp Int64, TExp Int32)]
histo_sizes) (((SegHistSlug,
   ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
   ([TExp Int64], TExp Int64, TExp Int32))
  -> InKernelGen ())
 -> InKernelGen ())
-> ((SegHistSlug,
     ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
     ([TExp Int64], TExp Int64, TExp Int32))
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                \(SegHistSlug
slug, ([(VName, VName)]
dests, TV Int64
hist_H_chk, [TExp Int64] -> InKernelGen ()
_), ([TExp Int64]
histo_dims, TExp Int64
histo_size, TExp Int32
init_per_thread)) ->
                  SegHistSlug
-> [(VName, VName)]
-> TExp Int64
-> [TExp Int64]
-> TExp Int64
-> TExp Int32
-> InKernelGen ()
f SegHistSlug
slug [(VName, VName)]
dests (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
hist_H_chk) [TExp Int64]
histo_dims TExp Int64
histo_size TExp Int32
init_per_thread

        let onAllHistograms :: (VName
 -> VName
 -> HistOp GPUMem
 -> SubExp
 -> TExp Int32
 -> TExp Int32
 -> [TExp Int64]
 -> [TExp Int64]
 -> InKernelGen ())
-> InKernelGen ()
onAllHistograms VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ()
f =
              (SegHistSlug
 -> [(VName, VName)]
 -> TExp Int64
 -> [TExp Int64]
 -> TExp Int64
 -> TExp Int32
 -> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
  -> [(VName, VName)]
  -> TExp Int64
  -> [TExp Int64]
  -> TExp Int64
  -> TExp Int32
  -> InKernelGen ())
 -> InKernelGen ())
-> (SegHistSlug
    -> [(VName, VName)]
    -> TExp Int64
    -> [TExp Int64]
    -> TExp Int64
    -> TExp Int32
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests TExp Int64
hist_H_chk [TExp Int64]
histo_dims TExp Int64
histo_size TExp Int32
init_per_thread -> do
                let group_hists_size :: TExp Int32
group_hists_size = TExp Int32
num_subhistos_per_group TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size

                [((VName, VName), SubExp)]
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, VName)] -> [SubExp] -> [((VName, VName), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(VName, VName)]
dests (HistOp GPUMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral (HistOp GPUMem -> [SubExp]) -> HistOp GPUMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)) ((((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ())
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                  \((VName
dest_global, VName
dest_local), SubExp
ne) ->
                    String
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"local_i" TExp Int32
init_per_thread ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
                      TExp Int32
j <-
                        String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"j" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
                          TExp Int32
i TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
                            TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
                      TExp Int32
j_offset <-
                        String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"j_offset" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
                          TExp Int32
num_subhistos_per_group TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
gid_in_segment TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
j

                      TExp Int32
local_subhisto_i <- String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"local_subhisto_i" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size
                      let local_bucket_is :: [TExp Int64]
local_bucket_is = [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
histo_dims (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size
                          global_bucket_is :: [TExp Int64]
global_bucket_is =
                            [TExp Int64] -> TExp Int64
forall a. [a] -> a
head [TExp Int64]
local_bucket_is TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H_chk TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
:
                            [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a]
tail [TExp Int64]
local_bucket_is
                      TExp Int32
global_subhisto_i <- String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"global_subhisto_i" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int32
j_offset TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
histo_size

                      TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
group_hists_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                        VName
-> VName
-> HistOp GPUMem
-> SubExp
-> TExp Int32
-> TExp Int32
-> [TExp Int64]
-> [TExp Int64]
-> InKernelGen ()
f
                          VName
dest_local
                          VName
dest_global
                          (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)
                          SubExp
ne
                          TExp Int32
local_subhisto_i
                          TExp Int32
global_subhisto_i
                          [TExp Int64]
local_bucket_is
                          [TExp Int64]
global_bucket_is

        String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"initialize histograms in local memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          (VName
 -> VName
 -> HistOp GPUMem
 -> SubExp
 -> TExp Int32
 -> TExp Int32
 -> [TExp Int64]
 -> [TExp Int64]
 -> InKernelGen ())
-> InKernelGen ()
onAllHistograms ((VName
  -> VName
  -> HistOp GPUMem
  -> SubExp
  -> TExp Int32
  -> TExp Int32
  -> [TExp Int64]
  -> [TExp Int64]
  -> InKernelGen ())
 -> InKernelGen ())
-> (VName
    -> VName
    -> HistOp GPUMem
    -> SubExp
    -> TExp Int32
    -> TExp Int32
    -> [TExp Int64]
    -> [TExp Int64]
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
dest_local VName
dest_global HistOp GPUMem
op SubExp
ne TExp Int32
local_subhisto_i TExp Int32
global_subhisto_i [TExp Int64]
local_bucket_is [TExp Int64]
global_bucket_is ->
            String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"First subhistogram is initialised from global memory; others with neutral element." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
              let global_is :: [TExp Int64]
global_is = (VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
segment_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64
0] [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
global_bucket_is
                  local_is :: [TExp Int64]
local_is = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
local_subhisto_i TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
local_bucket_is
              TExp Bool -> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
                (TExp Int32
global_subhisto_i TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0)
                (VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest_local [TExp Int64]
local_is (VName -> SubExp
Var VName
dest_global) [TExp Int64]
global_is)
                ( Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape HistOp GPUMem
op) (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is ->
                    VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest_local ([TExp Int64]
local_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is) SubExp
ne []
                )

        -- Emit a barrier ('Imp.Barrier') with 'Imp.FenceLocal' between
        -- the initialisation code above and the update loop below.
        -- NOTE(review): presumably this ensures every thread in the
        -- group sees the fully initialised local subhistograms before
        -- any thread starts updating them -- confirm.
        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal

        TExp Int32
-> TExp Int32
-> TExp Int32
-> (TExp Int32 -> InKernelGen ())
-> InKernelGen ()
forall t.
IntExp t =>
TExp t
-> TExp t -> TExp t -> (TExp t -> InKernelGen ()) -> InKernelGen ()
kernelLoop TExp Int32
pgtid_in_segment TExp Int32
threads_per_segment (TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
segment_size') ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
ie -> do
          VName -> TExp Int32 -> InKernelGen ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
i_in_segment TExp Int32
ie

          -- We execute the bucket function once and update each histogram
          -- serially.  This also involves writing to the mapout arrays if
          -- this is the first chunk.

          Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
            let ([SubExp]
red_res, [SubExp]
map_res) =
                  Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElemT LetDecMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem GPUMem]
[PatElemT LetDecMem]
map_pes) ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$
                    (KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp ([KernelResult] -> [SubExp]) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody
                ([SubExp]
buckets, [SubExp]
vs) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegHistSlug] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SegHistSlug]
slugs) [SubExp]
red_res
                perOp :: [SubExp] -> [[SubExp]]
perOp = [Int] -> [SubExp] -> [[SubExp]]
forall a. [Int] -> [a] -> [[a]]
chunks ([Int] -> [SubExp] -> [[SubExp]])
-> [Int] -> [SubExp] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Int) -> [SegHistSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int) -> (SegHistSlug -> [VName]) -> SegHistSlug -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> [VName]
forall rep. HistOp rep -> [VName]
histDest (HistOp GPUMem -> [VName])
-> (SegHistSlug -> HistOp GPUMem) -> SegHistSlug -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs

            TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
chk_i TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"save map-out results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                [(PatElemT LetDecMem, SubExp)]
-> ((PatElemT LetDecMem, SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem] -> [SubExp] -> [(PatElemT LetDecMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem GPUMem]
[PatElemT LetDecMem]
map_pes [SubExp]
map_res) (((PatElemT LetDecMem, SubExp) -> InKernelGen ())
 -> InKernelGen ())
-> ((PatElemT LetDecMem, SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
pe, SubExp
se) ->
                  VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
                    (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
pe)
                    ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
space_is)
                    SubExp
se
                    []

            [(HistOp GPUMem,
  ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
  SubExp, [SubExp])]
-> ((HistOp GPUMem,
     ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
     SubExp, [SubExp])
    -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp GPUMem]
-> [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
-> [SubExp]
-> [[SubExp]]
-> [(HistOp GPUMem,
     ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
     SubExp, [SubExp])]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 ((SegHistSlug -> HistOp GPUMem) -> [SegHistSlug] -> [HistOp GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp GPUMem
slugOp [SegHistSlug]
slugs) [([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ())]
histograms [SubExp]
buckets ([SubExp] -> [[SubExp]]
perOp [SubExp]
vs)) (((HistOp GPUMem,
   ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
   SubExp, [SubExp])
  -> InKernelGen ())
 -> InKernelGen ())
-> ((HistOp GPUMem,
     ([(VName, VName)], TV Int64, [TExp Int64] -> InKernelGen ()),
     SubExp, [SubExp])
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
              \( HistOp SubExp
dest_w SubExp
_ [VName]
_ [SubExp]
_ Shape
shape LambdaT GPUMem
lam,
                 ([(VName, VName)]
_, TV Int64
hist_H_chk, [TExp Int64] -> InKernelGen ()
do_op),
                 SubExp
bucket,
                 [SubExp]
vs'
                 ) -> do
                  let chk_beg :: TExp Int64
chk_beg = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
hist_H_chk
                      bucket' :: TExp Int64
bucket' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
bucket
                      dest_w' :: TExp Int64
dest_w' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
dest_w
                      bucket_in_bounds :: TExp Bool
bucket_in_bounds =
                        TExp Int64
bucket' TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
dest_w'
                          TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
chk_beg TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
bucket'
                          TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
bucket' TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. (TExp Int64
chk_beg TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
hist_H_chk)
                      bucket_is :: [TExp Int64]
bucket_is = [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
thread_local_subhisto_i, TExp Int64
bucket' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
chk_beg]
                      vs_params :: [Param LetDecMem]
vs_params = Int -> [Param LetDecMem] -> [Param LetDecMem]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') ([Param LetDecMem] -> [Param LetDecMem])
-> [Param LetDecMem] -> [Param LetDecMem]
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams LambdaT GPUMem
lam

                  String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                    TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
bucket_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
                      [LParam GPUMem] -> InKernelGen ()
forall rep inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam GPUMem] -> InKernelGen ())
-> [LParam GPUMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams LambdaT GPUMem
lam
                      Shape -> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall rep r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
shape (([TExp Int64] -> InKernelGen ()) -> InKernelGen ())
-> ([TExp Int64] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is -> do
                        [(Param LetDecMem, SubExp)]
-> ((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [SubExp] -> [(Param LetDecMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
vs_params [SubExp]
vs') (((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, SubExp
v) ->
                          VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] SubExp
v [TExp Int64]
is
                        [TExp Int64] -> InKernelGen ()
do_op ([TExp Int64]
bucket_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is)

        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceGlobal

        String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Compact the multiple local memory subhistograms to result in global memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          (SegHistSlug
 -> [(VName, VName)]
 -> TExp Int64
 -> [TExp Int64]
 -> TExp Int64
 -> TExp Int32
 -> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
  -> [(VName, VName)]
  -> TExp Int64
  -> [TExp Int64]
  -> TExp Int64
  -> TExp Int32
  -> InKernelGen ())
 -> InKernelGen ())
-> (SegHistSlug
    -> [(VName, VName)]
    -> TExp Int64
    -> [TExp Int64]
    -> TExp Int64
    -> TExp Int32
    -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests TExp Int64
hist_H_chk [TExp Int64]
histo_dims TExp Int64
_histo_size TExp Int32
bins_per_thread -> do
            TV Int64
trunc_H <-
              String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"trunc_H" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
                TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 TExp Int64
hist_H_chk (TExp Int64 -> TExp Int64) -> TExp Int64 -> TExp Int64
forall a b. (a -> b) -> a -> b
$
                  SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (HistOp GPUMem -> SubExp
forall rep. HistOp rep -> SubExp
histWidth (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug))
                    TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* [TExp Int64] -> TExp Int64
forall a. [a] -> a
head [TExp Int64]
histo_dims
            let trunc_histo_dims :: [TExp Int64]
trunc_histo_dims =
                  TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
trunc_H TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
:
                  (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape (SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug)))
            TExp Int32
trunc_histo_size <- String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"histo_size" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
trunc_histo_dims

            String
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"local_i" TExp Int32
bins_per_thread ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
i -> do
              TExp Int32
j <-
                String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"j" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
                  TExp Int32
i TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants)
                    TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
              TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
j TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
trunc_histo_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
                -- We are responsible for compacting the flat bin 'j', which
                -- we immediately unflatten.
                let local_bucket_is :: [TExp Int64]
local_bucket_is = [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
histo_dims (TExp Int64 -> [TExp Int64]) -> TExp Int64 -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
j
                    global_bucket_is :: [TExp Int64]
global_bucket_is =
                      [TExp Int64] -> TExp Int64
forall a. [a] -> a
head [TExp Int64]
local_bucket_is TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chk_i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H_chk TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
:
                      [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a]
tail [TExp Int64]
local_bucket_is
                [LParam GPUMem] -> InKernelGen ()
forall rep inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam GPUMem] -> InKernelGen ())
-> [LParam GPUMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams (LambdaT GPUMem -> [LParam GPUMem])
-> LambdaT GPUMem -> [LParam GPUMem]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> LambdaT GPUMem
forall rep. HistOp rep -> Lambda rep
histOp (HistOp GPUMem -> LambdaT GPUMem)
-> HistOp GPUMem -> LambdaT GPUMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug
                let ([VName]
global_dests, [VName]
local_dests) = [(VName, VName)] -> ([VName], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, VName)]
dests
                    ([Param LetDecMem]
xparams, [Param LetDecMem]
yparams) =
                      Int -> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
local_dests) ([Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem]))
-> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a b. (a -> b) -> a -> b
$
                        LambdaT GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams (LambdaT GPUMem -> [LParam GPUMem])
-> LambdaT GPUMem -> [LParam GPUMem]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> LambdaT GPUMem
forall rep. HistOp rep -> Lambda rep
histOp (HistOp GPUMem -> LambdaT GPUMem)
-> HistOp GPUMem -> LambdaT GPUMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug

                String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Read values from subhistogram 0." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                  [(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
xparams [VName]
local_dests) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
xp, VName
subhisto) ->
                    VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
                      (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
xp)
                      []
                      (VName -> SubExp
Var VName
subhisto)
                      (TExp Int64
0 TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
local_bucket_is)

                String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Accumulate based on values in other subhistograms." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                  String
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall t rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"subhisto_id" (TExp Int32
num_subhistos_per_group TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1) ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
subhisto_id -> do
                    [(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
yparams [VName]
local_dests) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
yp, VName
subhisto) ->
                      VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
                        (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
yp)
                        []
                        (VName -> SubExp
Var VName
subhisto)
                        (TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
subhisto_id TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1 TExp Int64 -> [TExp Int64] -> [TExp Int64]
forall a. a -> [a] -> [a]
: [TExp Int64]
local_bucket_is)
                    [Param LetDecMem] -> Body GPUMem -> InKernelGen ()
forall dec rep r op. [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [Param LetDecMem]
xparams (Body GPUMem -> InKernelGen ()) -> Body GPUMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT GPUMem -> Body GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody (LambdaT GPUMem -> Body GPUMem) -> LambdaT GPUMem -> Body GPUMem
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> LambdaT GPUMem
forall rep. HistOp rep -> Lambda rep
histOp (HistOp GPUMem -> LambdaT GPUMem)
-> HistOp GPUMem -> LambdaT GPUMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp GPUMem
slugOp SegHistSlug
slug

                String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"Put final bucket value in global memory." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
                  let global_is :: [TExp Int64]
global_is =
                        (VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
segment_is
                          [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
group_id TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`rem` Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment]
                          [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
global_bucket_is
                  [(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
xparams [VName]
global_dests) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
xp, VName
global_dest) ->
                    VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
global_dest [TExp Int64]
global_is (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
xp) []

-- | Generate the host-side code that drives the local-memory
-- histogram passes.
--
-- The group count and group size are first converted to 64-bit
-- expressions.  A debug print of the number of local subhistograms
-- per group is then emitted, the local intermediate arrays are
-- prepared once via 'prepareIntermediateArraysLocal', and finally
-- 'histKernelLocalPass' is invoked inside an 'sFor' loop whose
-- counter @chk_i@ is bounded by @hist_S@.
histKernelLocal ::
  -- | Variable whose value is reported as the number of local
  -- subhistograms per group.
  TV Int32 ->
  -- | Groups per segment; forwarded unchanged to the helpers.
  Count NumGroups (Imp.TExp Int64) ->
  -- | Pattern elements for the map-out results; forwarded to each pass.
  [PatElem GPUMem] ->
  -- | Number of groups (converted to an 'Int64' expression here).
  Count NumGroups SubExp ->
  -- | Group size (converted to an 'Int64' expression here).
  Count GroupSize SubExp ->
  SegSpace ->
  -- | Bound of the @chk_i@ loop, i.e. how many passes are generated
  -- at run time.
  Imp.TExp Int32 ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  CallKernelGen ()
histKernelLocal
  num_subhistos_per_group_var
  groups_per_segment
  map_pes
  num_groups
  group_size
  space
  hist_S
  slugs
  kbody = do
    let num_groups' = toInt64Exp <$> num_groups
        group_size' = toInt64Exp <$> group_size
        num_subhistos_per_group = tvExp num_subhistos_per_group_var

    emit $
      Imp.DebugPrint "Number of local subhistograms per group" $
        Just $ untyped num_subhistos_per_group

    -- The initialisers are computed once, outside the pass loop, and
    -- handed to every pass.
    init_histograms <-
      prepareIntermediateArraysLocal
        num_subhistos_per_group_var
        groups_per_segment
        space
        slugs

    -- One call to 'histKernelLocalPass' per value of the loop counter;
    -- the pass receives both the bound and the current counter.
    sFor "chk_i" hist_S $ \chk_i ->
      histKernelLocalPass
        num_subhistos_per_group_var
        groups_per_segment
        map_pes
        num_groups'
        group_size'
        space
        slugs
        kbody
        init_histograms
        hist_S
        chk_i

-- | The maximum number of passes we are willing to accept for this
-- kind of atomic update.
--
-- The limit depends only on which 'AtomicUpdate' constructor
-- 'slugAtomicUpdate' yields for the slug: 3 for 'AtomicPrim', 4 for
-- 'AtomicCAS', and 6 for 'AtomicLocking'.
slugMaxLocalMemPasses :: SegHistSlug -> Int
slugMaxLocalMemPasses slug =
  case slugAtomicUpdate slug of
    AtomicPrim _ -> 3
    AtomicCAS _ -> 4
    AtomicLocking _ -> 6

localMemoryCase ::
  [PatElem GPUMem] ->
  Imp.TExp Int32 ->
  SegSpace ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int32 ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  CallKernelGen (Imp.TExp Bool, CallKernelGen ())
localMemoryCase :: [PatElem GPUMem]
-> TExp Int32
-> SegSpace
-> TExp Int64
-> TExp Int64
-> TExp Int64
-> TExp Int32
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen (TExp Bool, CallKernelGen ())
localMemoryCase [PatElem GPUMem]
map_pes TExp Int32
hist_T SegSpace
space TExp Int64
hist_H TExp Int64
hist_el_size TExp Int64
hist_N TExp Int32
_ [SegHistSlug]
slugs KernelBody GPUMem
kbody = do
  let space_sizes :: [SubExp]
space_sizes = SegSpace -> [SubExp]
segSpaceDims SegSpace
space
      segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. [a] -> [a]
init [SubExp]
space_sizes
      segmented :: Bool
segmented = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [SubExp]
segment_dims

  TV Int64
hist_L <- String -> PrimType -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"hist_L" PrimType
int32
  HostOp -> CallKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
hist_L) SizeClass
Imp.SizeLocalMemory

  TV Any
max_group_size <- String -> PrimType -> ImpM GPUMem HostEnv HostOp (TV Any)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"max_group_size" PrimType
int32
  HostOp -> CallKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
max_group_size) SizeClass
Imp.SizeGroup
  let group_size :: Count GroupSize SubExp
group_size = SubExp -> Count GroupSize SubExp
forall u e. e -> Count u e
Imp.Count (SubExp -> Count GroupSize SubExp)
-> SubExp -> Count GroupSize SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
max_group_size
  Count NumGroups SubExp
num_groups <-
    (TV Int64 -> Count NumGroups SubExp)
-> ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumGroups SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Imp.Count (SubExp -> Count NumGroups SubExp)
-> (TV Int64 -> SubExp) -> TV Int64 -> Count NumGroups SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TV Int64 -> SubExp
forall t. TV t -> SubExp
tvSize) (ImpM GPUMem HostEnv HostOp (TV Int64)
 -> ImpM GPUMem HostEnv HostOp (Count NumGroups SubExp))
-> ImpM GPUMem HostEnv HostOp (TV Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumGroups SubExp)
forall a b. (a -> b) -> a -> b
$
      String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"num_groups" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
        TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_T TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount Count GroupSize SubExp
group_size)
  let num_groups' :: Count NumGroups (TExp Int64)
num_groups' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> Count NumGroups SubExp -> Count NumGroups (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count NumGroups SubExp
num_groups
      group_size' :: Count GroupSize (TExp Int64)
group_size' = SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> Count GroupSize SubExp -> Count GroupSize (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count GroupSize SubExp
group_size

  let r64 :: TPrimExp t v -> TPrimExp Double v
r64 = PrimExp v -> TPrimExp Double v
forall v. PrimExp v -> TPrimExp Double v
isF64 (PrimExp v -> TPrimExp Double v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Double v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> FloatType -> ConvOp
SIToFP IntType
Int64 FloatType
Float64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall t v. TPrimExp t v -> PrimExp v
untyped
      t64 :: TPrimExp t v -> TPrimExp Int64 v
t64 = PrimExp v -> TPrimExp Int64 v
forall v. PrimExp v -> TPrimExp Int64 v
isInt64 (PrimExp v -> TPrimExp Int64 v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> TPrimExp Int64 v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (FloatType -> IntType -> ConvOp
FPToSI FloatType
Float64 IntType
Int64) (PrimExp v -> PrimExp v)
-> (TPrimExp t v -> PrimExp v) -> TPrimExp t v -> PrimExp v
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp t v -> PrimExp v
forall t v. TPrimExp t v -> PrimExp v
untyped

  -- M approximation.
  TExp Double
hist_m' <-
    String -> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_m_prime" (TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double))
-> TExp Double -> ImpM GPUMem HostEnv HostOp (TExp Double)
forall a b. (a -> b) -> a -> b
$
      TExp Int64 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64
        ( TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64
            (TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
hist_L TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64
hist_el_size))
            (TExp Int64
hist_N TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups'))
        )
        TExp Double -> TExp Double -> TExp Double
forall a. Fractional a => a -> a -> a
/ TExp Int64 -> TExp Double
forall t v. TPrimExp t v -> TPrimExp Double v
r64 TExp Int64
hist_H

  let hist_B :: TExp Int64
hist_B = Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size'

  -- M in the paper, but not adjusted for asymptotic efficiency.
  TExp Int64
hist_M0 <-
    String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_M0" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
      TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMax64 TExp Int64
1 (TExp Int64 -> TExp Int64) -> TExp Int64 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TExp Double -> TExp Int64
forall t v. TPrimExp t v -> TPrimExp Int64 v
t64 TExp Double
hist_m') TExp Int64
hist_B

  -- Minimal sequential chunking factor.
  let q_small :: TExp Int64
q_small = TExp Int64
2

  -- The number of segments/histograms produced.
  TExp Int64
hist_Nout <- String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_Nout" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
segment_dims

  TExp Int64
hist_Nin <- String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_Nin" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64) -> SubExp -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. [a] -> a
last [SubExp]
space_sizes

  -- Maximum M for work efficiency.
  TExp Int64
work_asymp_M_max <-
    if Bool
segmented
      then do
        TExp Int32
hist_T_hist_min <-
          String -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_T_hist_min" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
            TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
              TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 (TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
hist_Nin TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
hist_Nout) (TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_T)
                TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
hist_Nout

        -- Number of groups, rounded up.
        let r :: TExp Int32
r = TExp Int32
hist_T_hist_min TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
hist_B

        String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"work_asymp_M_max" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
hist_Nin TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` (TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
r TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H)
      else
        String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"work_asymp_M_max" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
          (TExp Int64
hist_Nout TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_N)
            TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` ( (TExp Int64
q_small TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_H)
                       TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` [SegHistSlug] -> TExp Int64
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs
                   )

  -- Number of subhistograms per result histogram.
  TV Int32
hist_M <- String -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"hist_M" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TV Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 TExp Int64
hist_M0 TExp Int64
work_asymp_M_max

  -- hist_M may be zero (which we'll check for below), but we need it
  -- for some divisions first, so crudely make a nonzero form.
  let hist_M_nonzero :: TExp Int32
hist_M_nonzero = TExp Int32 -> TExp Int32 -> TExp Int32
forall v. TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMax32 TExp Int32
1 (TExp Int32 -> TExp Int32) -> TExp Int32 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M

  -- "Cooperation factor" - the number of threads cooperatively
  -- working on the same (sub)histogram.
  TExp Int64
hist_C <-
    String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_C" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
      TExp Int64
hist_B TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
hist_M_nonzero

  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local hist_M0" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_M0
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local work asymp M max" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
work_asymp_M_max
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local C" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_C
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local B" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_B
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local M" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M
  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
    String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local memory needed" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
      Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64
hist_H TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
hist_el_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M)

  -- local_mem_needed is what we need to keep a single bucket in local
  -- memory - this is an absolute minimum.  We can fit anything else
  -- by doing multiple passes, although more than a few is
  -- (heuristically) not efficient.
  TExp Int64
local_mem_needed <-
    String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"local_mem_needed" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
      TExp Int64
hist_el_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M)
  TExp Int32
hist_S <-
    String -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_S" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
      TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
        (TExp Int64
hist_H TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
local_mem_needed) TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
hist_L
  let max_S :: TExp Int32
max_S = case KernelBody GPUMem -> Passage
bodyPassage KernelBody GPUMem
kbody of
        Passage
MustBeSinglePass -> TExp Int32
1
        Passage
MayBeMultiPass -> Int -> TExp Int32
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int -> TExp Int32) -> Int -> TExp Int32
forall a b. (a -> b) -> a -> b
$ [Int] -> Int
forall a (f :: * -> *). (Num a, Ord a, Foldable f) => f a -> a
maxinum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Int) -> [SegHistSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> Int
slugMaxLocalMemPasses [SegHistSlug]
slugs

  Count NumGroups (TExp Int64)
groups_per_segment <-
    if Bool
segmented
      then
        (TExp Int64 -> Count NumGroups (TExp Int64))
-> ImpM GPUMem HostEnv HostOp (TExp Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumGroups (TExp Int64))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap TExp Int64 -> Count NumGroups (TExp Int64)
forall u e. e -> Count u e
Count (ImpM GPUMem HostEnv HostOp (TExp Int64)
 -> ImpM GPUMem HostEnv HostOp (Count NumGroups (TExp Int64)))
-> ImpM GPUMem HostEnv HostOp (TExp Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumGroups (TExp Int64))
forall a b. (a -> b) -> a -> b
$
          String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"groups_per_segment" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups' TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64
hist_Nout
      else Count NumGroups (TExp Int64)
-> ImpM GPUMem HostEnv HostOp (Count NumGroups (TExp Int64))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Count NumGroups (TExp Int64)
num_groups'

  -- We only use local memory if the number of updates per histogram
  -- at least matches the histogram size, as otherwise it is not
  -- asymptotically efficient.  This mostly matters for the segmented
  -- case.
  let pick_local :: TExp Bool
pick_local =
        TExp Int64
hist_Nin TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>=. TExp Int64
hist_H
          TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. (TExp Int64
local_mem_needed TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
hist_L)
          TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. (TExp Int32
hist_S TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int32
max_S)
          TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int64
hist_C TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
hist_B
          TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int32
0

      run :: CallKernelGen ()
run = do
        Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"## Using local memory" Maybe Exp
forall a. Maybe a
Nothing
        Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Histogram size (H)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_H
        Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Multiplication degree (M)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int32 -> Exp) -> TExp Int32 -> Exp
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
hist_M
        Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Cooperation level (C)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_C
        Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of chunks (S)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_S
        Bool -> CallKernelGen () -> CallKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
segmented (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
          Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Groups per segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
groups_per_segment
        TV Int32
-> Count NumGroups (TExp Int64)
-> [PatElem GPUMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> TExp Int32
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen ()
histKernelLocal
          TV Int32
hist_M
          Count NumGroups (TExp Int64)
groups_per_segment
          [PatElem GPUMem]
map_pes
          Count NumGroups SubExp
num_groups
          Count GroupSize SubExp
group_size
          SegSpace
space
          TExp Int32
hist_S
          [SegHistSlug]
slugs
          KernelBody GPUMem
kbody

  (TExp Bool, CallKernelGen ())
-> CallKernelGen (TExp Bool, CallKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return (TExp Bool
pick_local, CallKernelGen ()
run)

-- | Generate code for a segmented histogram called from the host.
compileSegHist ::
  Pat GPUMem ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  SegSpace ->
  [HistOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegHist :: Pat GPUMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [HistOp GPUMem]
-> KernelBody GPUMem
-> CallKernelGen ()
compileSegHist (Pat [PatElem GPUMem]
pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [HistOp GPUMem]
ops KernelBody GPUMem
kbody = do
  -- Most of this function is not the histogram part itself, but
  -- rather figuring out whether to use a local or global memory
  -- strategy, as well as collapsing the subhistograms produced (which
  -- are always in global memory, but their number may vary).
  let num_groups' :: Count NumGroups (TExp Int64)
num_groups' = (SubExp -> TExp Int64)
-> Count NumGroups SubExp -> Count NumGroups (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count NumGroups SubExp
num_groups
      group_size' :: Count GroupSize (TExp Int64)
group_size' = (SubExp -> TExp Int64)
-> Count GroupSize SubExp -> Count GroupSize (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp Count GroupSize SubExp
group_size
      dims :: [TExp Int64]
dims = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space

      num_red_res :: Int
num_red_res = [HistOp GPUMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp GPUMem]
ops Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((HistOp GPUMem -> Int) -> [HistOp GPUMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (HistOp GPUMem -> [SubExp]) -> HistOp GPUMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral) [HistOp GPUMem]
ops)
      ([PatElemT LetDecMem]
all_red_pes, [PatElemT LetDecMem]
map_pes) = Int
-> [PatElemT LetDecMem]
-> ([PatElemT LetDecMem], [PatElemT LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
num_red_res [PatElem GPUMem]
[PatElemT LetDecMem]
pes
      segment_size :: TExp Int64
segment_size = [TExp Int64] -> TExp Int64
forall a. [a] -> a
last [TExp Int64]
dims

  ([Count Bytes (TExp Int64)]
op_hs, [Count Bytes (TExp Int64)]
op_seg_hs, [SegHistSlug]
slugs) <- [(Count Bytes (TExp Int64), Count Bytes (TExp Int64), SegHistSlug)]
-> ([Count Bytes (TExp Int64)], [Count Bytes (TExp Int64)],
    [SegHistSlug])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(Count Bytes (TExp Int64), Count Bytes (TExp Int64),
   SegHistSlug)]
 -> ([Count Bytes (TExp Int64)], [Count Bytes (TExp Int64)],
     [SegHistSlug]))
-> ImpM
     GPUMem
     HostEnv
     HostOp
     [(Count Bytes (TExp Int64), Count Bytes (TExp Int64), SegHistSlug)]
-> ImpM
     GPUMem
     HostEnv
     HostOp
     ([Count Bytes (TExp Int64)], [Count Bytes (TExp Int64)],
      [SegHistSlug])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (HistOp GPUMem
 -> CallKernelGen
      (Count Bytes (TExp Int64), Count Bytes (TExp Int64), SegHistSlug))
-> [HistOp GPUMem]
-> ImpM
     GPUMem
     HostEnv
     HostOp
     [(Count Bytes (TExp Int64), Count Bytes (TExp Int64), SegHistSlug)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SegSpace
-> HistOp GPUMem
-> CallKernelGen
     (Count Bytes (TExp Int64), Count Bytes (TExp Int64), SegHistSlug)
computeHistoUsage SegSpace
space) [HistOp GPUMem]
ops
  TExp Int64
h <- String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"h" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ Count Bytes (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount (Count Bytes (TExp Int64) -> TExp Int64)
-> Count Bytes (TExp Int64) -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [Count Bytes (TExp Int64)] -> Count Bytes (TExp Int64)
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Count Bytes (TExp Int64)]
op_hs
  TExp Int64
seg_h <- String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"seg_h" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ Count Bytes (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount (Count Bytes (TExp Int64) -> TExp Int64)
-> Count Bytes (TExp Int64) -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [Count Bytes (TExp Int64)] -> Count Bytes (TExp Int64)
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Count Bytes (TExp Int64)]
op_seg_hs

  -- Check for emptiness to avoid division-by-zero.
  TExp Bool -> CallKernelGen () -> CallKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless (TExp Int64
seg_h TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0) (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    -- Maximum group size (or actual, in this case).
    let hist_B :: TExp Int64
hist_B = Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size'

    -- Size of a histogram.
    TExp Int64
hist_H <- String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_H" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (HistOp GPUMem -> TExp Int64) -> [HistOp GPUMem] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> (HistOp GPUMem -> SubExp) -> HistOp GPUMem -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> SubExp
forall rep. HistOp rep -> SubExp
histWidth) [HistOp GPUMem]
ops

    -- Size of a single histogram element.  Actually the weighted
    -- average of histogram elements in cases where we have more than
    -- one histogram operation, plus any locks.
    let lockSize :: SegHistSlug -> Maybe a
lockSize SegHistSlug
slug = case SegHistSlug -> AtomicUpdate GPUMem KernelEnv
slugAtomicUpdate SegHistSlug
slug of
          AtomicLocking {} -> a -> Maybe a
forall a. a -> Maybe a
Just (a -> Maybe a) -> a -> Maybe a
forall a b. (a -> b) -> a -> b
$ PrimType -> a
forall a. Num a => PrimType -> a
primByteSize PrimType
int32
          AtomicUpdate GPUMem KernelEnv
_ -> Maybe a
forall a. Maybe a
Nothing
    TExp Int64
hist_el_size <-
      String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_el_size" (TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$
        (TExp Int64 -> TExp Int64 -> TExp Int64)
-> TExp Int64 -> [TExp Int64] -> TExp Int64
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
(+) (TExp Int64
h TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` TExp Int64
hist_H) ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$
          (SegHistSlug -> Maybe (TExp Int64))
-> [SegHistSlug] -> [TExp Int64]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe SegHistSlug -> Maybe (TExp Int64)
forall a. Num a => SegHistSlug -> Maybe a
lockSize [SegHistSlug]
slugs

    -- Input elements contributing to each histogram.
    TExp Int64
hist_N <- String -> TExp Int64 -> ImpM GPUMem HostEnv HostOp (TExp Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_N" TExp Int64
segment_size

    -- Compute RF as the average RF over all the histograms.
    TExp Int32
hist_RF <-
      String -> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"hist_RF" (TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem HostEnv HostOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$
        TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$
          [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((SegHistSlug -> TExp Int64) -> [SegHistSlug] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> (SegHistSlug -> SubExp) -> SegHistSlug -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> SubExp
forall rep. HistOp rep -> SubExp
histRaceFactor (HistOp GPUMem -> SubExp)
-> (SegHistSlug -> HistOp GPUMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp GPUMem
slugOp) [SegHistSlug]
slugs)
            TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` [SegHistSlug] -> TExp Int64
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs

    let hist_T :: TExp Int32
hist_T = TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32) -> TExp Int64 -> TExp Int32
forall a b. (a -> b) -> a -> b
$ Count NumGroups (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count NumGroups (TExp Int64)
num_groups' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count GroupSize (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
unCount Count GroupSize (TExp Int64)
group_size'
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"\n# SegHist" Maybe Exp
forall a. Maybe a
Nothing
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of threads (T)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_T
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Desired group size (B)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_B
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Histogram size (H)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_H
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Input elements per histogram (N)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_N
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
      String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
        Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped (TExp Int64 -> Exp) -> TExp Int64 -> Exp
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ ((VName, SubExp) -> TExp Int64)
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> ((VName, SubExp) -> SubExp) -> (VName, SubExp) -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) [(VName, SubExp)]
segment_dims
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Histogram element size (el_size)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
hist_el_size
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Race factor (RF)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int32
hist_RF
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Memory per set of subhistograms per segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
h
    Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Memory per set of subhistograms times segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
seg_h

    (TExp Bool
use_local_memory, CallKernelGen ()
run_in_local_memory) <-
      [PatElem GPUMem]
-> TExp Int32
-> SegSpace
-> TExp Int64
-> TExp Int64
-> TExp Int64
-> TExp Int32
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen (TExp Bool, CallKernelGen ())
localMemoryCase [PatElem GPUMem]
[PatElemT LetDecMem]
map_pes TExp Int32
hist_T SegSpace
space TExp Int64
hist_H TExp Int64
hist_el_size TExp Int64
hist_N TExp Int32
hist_RF [SegHistSlug]
slugs KernelBody GPUMem
kbody

    TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf TExp Bool
use_local_memory CallKernelGen ()
run_in_local_memory (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
      [PatElem GPUMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegHistSlug]
-> KernelBody GPUMem
-> CallKernelGen ()
histKernelGlobal [PatElem GPUMem]
[PatElemT LetDecMem]
map_pes Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegHistSlug]
slugs KernelBody GPUMem
kbody

    let pes_per_op :: [[PatElemT LetDecMem]]
pes_per_op = [Int] -> [PatElemT LetDecMem] -> [[PatElemT LetDecMem]]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp GPUMem -> Int) -> [HistOp GPUMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int)
-> (HistOp GPUMem -> [VName]) -> HistOp GPUMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp GPUMem -> [VName]
forall rep. HistOp rep -> [VName]
histDest) [HistOp GPUMem]
ops) [PatElemT LetDecMem]
all_red_pes

    [(SegHistSlug, [PatElemT LetDecMem], HistOp GPUMem)]
-> ((SegHistSlug, [PatElemT LetDecMem], HistOp GPUMem)
    -> CallKernelGen ())
-> CallKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegHistSlug]
-> [[PatElemT LetDecMem]]
-> [HistOp GPUMem]
-> [(SegHistSlug, [PatElemT LetDecMem], HistOp GPUMem)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegHistSlug]
slugs [[PatElemT LetDecMem]]
pes_per_op [HistOp GPUMem]
ops) (((SegHistSlug, [PatElemT LetDecMem], HistOp GPUMem)
  -> CallKernelGen ())
 -> CallKernelGen ())
-> ((SegHistSlug, [PatElemT LetDecMem], HistOp GPUMem)
    -> CallKernelGen ())
-> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegHistSlug
slug, [PatElemT LetDecMem]
red_pes, HistOp GPUMem
op) -> do
      let num_histos :: TV Int64
num_histos = SegHistSlug -> TV Int64
slugNumSubhistos SegHistSlug
slug
          subhistos :: [VName]
subhistos = (SubhistosInfo -> VName) -> [SubhistosInfo] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map SubhistosInfo -> VName
subhistosArray ([SubhistosInfo] -> [VName]) -> [SubhistosInfo] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> [SubhistosInfo]
slugSubhistos SegHistSlug
slug

      let unitHistoCase :: CallKernelGen ()
unitHistoCase =
            -- This is OK because the memory blocks are at least as
            -- large as the ones we are supposed to use for the result.
            [(PatElemT LetDecMem, VName)]
-> ((PatElemT LetDecMem, VName) -> CallKernelGen ())
-> CallKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem] -> [VName] -> [(PatElemT LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LetDecMem]
red_pes [VName]
subhistos) (((PatElemT LetDecMem, VName) -> CallKernelGen ())
 -> CallKernelGen ())
-> ((PatElemT LetDecMem, VName) -> CallKernelGen ())
-> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
pe, VName
subhisto) -> do
              VName
pe_mem <-
                MemLoc -> VName
memLocName (MemLoc -> VName) -> (ArrayEntry -> MemLoc) -> ArrayEntry -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ArrayEntry -> MemLoc
entryArrayLoc
                  (ArrayEntry -> VName)
-> ImpM GPUMem HostEnv HostOp ArrayEntry
-> ImpM GPUMem HostEnv HostOp VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem HostEnv HostOp ArrayEntry
forall rep r op. VName -> ImpM rep r op ArrayEntry
lookupArray (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
pe)
              VName
subhisto_mem <-
                MemLoc -> VName
memLocName (MemLoc -> VName) -> (ArrayEntry -> MemLoc) -> ArrayEntry -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ArrayEntry -> MemLoc
entryArrayLoc
                  (ArrayEntry -> VName)
-> ImpM GPUMem HostEnv HostOp ArrayEntry
-> ImpM GPUMem HostEnv HostOp VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem HostEnv HostOp ArrayEntry
forall rep r op. VName -> ImpM rep r op ArrayEntry
lookupArray VName
subhisto
              Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> Space -> Code HostOp
forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
pe_mem VName
subhisto_mem (Space -> Code HostOp) -> Space -> Code HostOp
forall a b. (a -> b) -> a -> b
$ String -> Space
Space String
"device"

      TExp Bool
-> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
num_histos TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
1) CallKernelGen ()
unitHistoCase (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
        -- For the segmented reduction, we keep the segment dimensions
        -- unchanged.  To this, we add two dimensions: one over the number
        -- of buckets, and one over the number of subhistograms.  This
        -- inner dimension is the one that is collapsed in the reduction.
        let num_buckets :: SubExp
num_buckets = HistOp GPUMem -> SubExp
forall rep. HistOp rep -> SubExp
histWidth HistOp GPUMem
op

        VName
bucket_id <- String -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"bucket_id"
        VName
subhistogram_id <- String -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"subhistogram_id"
        [VName]
vector_ids <-
          (SubExp -> ImpM GPUMem HostEnv HostOp VName)
-> [SubExp] -> ImpM GPUMem HostEnv HostOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (ImpM GPUMem HostEnv HostOp VName
-> SubExp -> ImpM GPUMem HostEnv HostOp VName
forall a b. a -> b -> a
const (ImpM GPUMem HostEnv HostOp VName
 -> SubExp -> ImpM GPUMem HostEnv HostOp VName)
-> ImpM GPUMem HostEnv HostOp VName
-> SubExp
-> ImpM GPUMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ String -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"vector_id") ([SubExp] -> ImpM GPUMem HostEnv HostOp [VName])
-> [SubExp] -> ImpM GPUMem HostEnv HostOp [VName]
forall a b. (a -> b) -> a -> b
$
            Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape HistOp GPUMem
op

        VName
flat_gtid <- String -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"flat_gtid"

        let lvl :: SegLevel
lvl = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegVirt
SegVirt
            segred_space :: SegSpace
segred_space =
              VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
flat_gtid ([(VName, SubExp)] -> SegSpace) -> [(VName, SubExp)] -> SegSpace
forall a b. (a -> b) -> a -> b
$
                [(VName, SubExp)]
segment_dims
                  [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
bucket_id, SubExp
num_buckets)]
                  [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
vector_ids (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp GPUMem -> Shape
forall rep. HistOp rep -> Shape
histShape HistOp GPUMem
op)
                  [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
subhistogram_id, VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
num_histos)]

        let segred_op :: SegBinOp GPUMem
segred_op = Commutativity
-> LambdaT GPUMem -> [SubExp] -> Shape -> SegBinOp GPUMem
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
Commutative (HistOp GPUMem -> LambdaT GPUMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp GPUMem
op) (HistOp GPUMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral HistOp GPUMem
op) Shape
forall a. Monoid a => a
mempty
        Pat GPUMem
-> SegLevel
-> SegSpace
-> [SegBinOp GPUMem]
-> DoSegBody
-> CallKernelGen ()
compileSegRed' ([PatElemT LetDecMem] -> PatT LetDecMem
forall dec. [PatElemT dec] -> PatT dec
Pat [PatElemT LetDecMem]
red_pes) SegLevel
lvl SegSpace
segred_space [SegBinOp GPUMem
segred_op] (DoSegBody -> CallKernelGen ()) -> DoSegBody -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [TExp Int64])] -> InKernelGen ()
red_cont ->
          [(SubExp, [TExp Int64])] -> InKernelGen ()
red_cont ([(SubExp, [TExp Int64])] -> InKernelGen ())
-> [(SubExp, [TExp Int64])] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            ((VName -> (SubExp, [TExp Int64]))
 -> [VName] -> [(SubExp, [TExp Int64])])
-> [VName]
-> (VName -> (SubExp, [TExp Int64]))
-> [(SubExp, [TExp Int64])]
forall a b c. (a -> b -> c) -> b -> a -> c
flip (VName -> (SubExp, [TExp Int64]))
-> [VName] -> [(SubExp, [TExp Int64])]
forall a b. (a -> b) -> [a] -> [b]
map [VName]
subhistos ((VName -> (SubExp, [TExp Int64])) -> [(SubExp, [TExp Int64])])
-> (VName -> (SubExp, [TExp Int64])) -> [(SubExp, [TExp Int64])]
forall a b. (a -> b) -> a -> b
$ \VName
subhisto ->
              ( VName -> SubExp
Var VName
subhisto,
                (VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 ([VName] -> [TExp Int64]) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$
                  ((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst [(VName, SubExp)]
segment_dims [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName
subhistogram_id, VName
bucket_id] [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
vector_ids
              )

  Code HostOp -> CallKernelGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"" Maybe Exp
forall a. Maybe a
Nothing
  where
    segment_dims :: [(VName, SubExp)]
segment_dims = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
init ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space