{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
module Futhark.CodeGen.ImpGen.Kernels.SegHist
( compileSegHist )
where
import Control.Monad.Except
import Data.Maybe
import Data.List (foldl', genericLength, zip4, zip6)
import Prelude hiding (quot, rem)
import Futhark.MonadFreshNames
import Futhark.IR.KernelsMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Pass.ExplicitAllocations()
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.SegRed (compileSegRed')
import Futhark.CodeGen.ImpGen.Kernels.Base
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Futhark.Util (chunks, mapAccumLM, maxinum, splitFromEnd, takeLast)
import Futhark.Construct (fullSliceNum)
-- | One intermediate ("sub")histogram array for a single histogram
-- destination, paired with the host-side action that prepares its
-- memory.  Values of this type are built by 'computeHistoUsage'.
data SubhistosInfo = SubhistosInfo
  { subhistosArray :: VName
    -- ^ Name of the subhistogram array declared for one destination.
  , subhistosAlloc :: CallKernelGen ()
    -- ^ Host-level code generation action that gives the array its
    -- backing memory.  As constructed in 'computeHistoUsage' it either
    -- reuses the destination's memory (exactly one subhistogram) or
    -- allocates and initialises fresh memory (more than one).
  }
-- | Everything the code generator carries around for one histogram
-- operator of a segmented histogram.  Built by 'computeHistoUsage'.
data SegHistSlug = SegHistSlug
  { slugOp :: HistOp KernelsMem
    -- ^ The histogram operator itself.
  , slugNumSubhistos :: VName
    -- ^ Variable holding the number of subhistograms (declared as an
    -- @int32@ named @num_subhistos@ in 'computeHistoUsage').
  , slugSubhistos :: [SubhistosInfo]
    -- ^ One entry per destination array of the operator.
  , slugAtomicUpdate :: AtomicUpdate KernelsMem KernelEnv
    -- ^ How updates with this operator are carried out atomically
    -- (the result of 'atomicUpdateLocking' on the operator lambda).
  }
-- | The number of bytes taken up by a single histogram for the given
-- operator.
--
-- Every return type of the operator lambda is first extended with
-- @histShape op@ (via 'arrayOfShape') and then with @histWidth op@
-- (via 'arrayOfRow'); the 'typeSize' of the resulting types are
-- summed, and the total is passed through @sExt Int32@.
histoSpaceUsage :: HistOp KernelsMem
                -> Imp.Count Imp.Bytes Imp.Exp
histoSpaceUsage op =
  fmap (sExt Int32) $ sum $
  map (typeSize .
       (`arrayOfRow` histWidth op) .
       (`arrayOfShape` histShape op)) $
  lambdaReturnType $ histOp op
-- | Compute the memory needed per histogram and set up the
-- subhistogram arrays for one histogram operator.
--
-- The result is a triple:
--
--   * the size in bytes of one histogram ('histoSpaceUsage');
--
--   * that size multiplied by every dimension of the space except the
--     last one (the "segmented" size);
--
--   * the 'SegHistSlug' describing the operator, its subhistogram
--     arrays and its atomic update strategy.
--
-- As a side effect this declares the @num_subhistos@ variable and,
-- for every destination, a memory block and an array for the
-- subhistograms.  The memory itself is only set up when the
-- corresponding 'subhistosAlloc' action is run.
computeHistoUsage :: SegSpace
                  -> HistOp KernelsMem
                  -> CallKernelGen (Imp.Count Imp.Bytes Imp.Exp,
                                    Imp.Count Imp.Bytes Imp.Exp,
                                    SegHistSlug)
computeHistoUsage space op = do
  -- All dimensions of the space but the last are segment dimensions.
  let segment_dims = init $ unSegSpace space
      num_segments = length segment_dims

  -- The number of subhistograms is a run-time value; only the
  -- variable is declared here.
  num_subhistos <- dPrim "num_subhistos" int32

  subhisto_infos <- forM (zip (histDest op) (histNeutral op)) $ \(dest, ne) -> do
    dest_t <- lookupType dest
    dest_mem <- entryArrayLocation <$> lookupArray dest

    subhistos_mem <-
      sDeclareMem (baseString dest ++ "_subhistos_mem") (Space "device")

    -- Shape of the subhistogram array: the segment dimensions, then
    -- the number of subhistograms, then whatever remains of the
    -- destination's shape after stripping the segment dimensions.
    let subhistos_shape = Shape (map snd segment_dims ++ [Var num_subhistos]) <>
                          stripDims num_segments (arrayShape dest_t)
        subhistos_membind = ArrayIn subhistos_mem $ IxFun.iota $
                            map (primExpFromSubExp int32) $
                            shapeDims subhistos_shape

    subhistos <- sArray (baseString dest ++ "_subhistos")
                 (elemType dest_t) subhistos_shape subhistos_membind

    return $ SubhistosInfo subhistos $ do
      let -- Exactly one subhistogram: point the subhistogram memory
          -- at the destination's own memory instead of allocating.
          unitHistoCase =
            emit $
            Imp.SetMem subhistos_mem (memLocationName dest_mem) $
            Space "device"

          -- Several subhistograms: allocate room for all of them,
          -- fill with the neutral element, and copy the destination
          -- into subhistogram 0 of every segment.
          multiHistoCase = do
            let num_elems = foldl' (*) (Imp.var num_subhistos int32) $
                            map (toExp' int32) $ arrayDims dest_t

            let subhistos_mem_size =
                  Imp.bytes $
                  Imp.unCount (Imp.elements num_elems
                               `Imp.withElemType` elemType dest_t)

            sAlloc_ subhistos_mem subhistos_mem_size $ Space "device"
            sReplicate subhistos ne
            subhistos_t <- lookupType subhistos
            let slice = fullSliceNum (map (toExp' int32) $ arrayDims subhistos_t) $
                        map (unitSlice 0 . toExp' int32 . snd) segment_dims ++
                        [DimFix 0]
            sUpdate subhistos slice $ Var dest

      sIf (Imp.var num_subhistos int32 .==. 1) unitHistoCase multiHistoCase

  let h = histoSpaceUsage op
      segmented_h = h * product (map (Imp.bytes . toExp' int32) $
                                 init $ segSpaceDims space)

  atomics <- hostAtomics <$> askEnv

  return (h,
          segmented_h,
          SegHistSlug op num_subhistos subhisto_infos $
          atomicUpdateLocking atomics $ histOp op)
-- | Produce the in-kernel update function for one histogram operator
-- working on the given destination arrays in @Space "global"@.
--
-- The first argument is a lock array description that may already
-- have been created for an earlier operator.  It is returned
-- unchanged unless this operator needs locking and none exists yet,
-- in which case a new lock array is created and returned as @Just@.
prepareAtomicUpdateGlobal :: Maybe Locking -> [VName] -> SegHistSlug
                          -> CallKernelGen (Maybe Locking,
                                            [Imp.Exp] -> InKernelGen ())
prepareAtomicUpdateGlobal l dests slug =
  case (l, slugAtomicUpdate slug) of
    -- Primitive and CAS-based updates need no lock.
    (_, AtomicPrim f) -> return (l, f (Space "global") dests)
    (_, AtomicCAS f) -> return (l, f (Space "global") dests)
    -- Locking update, and a lock array is already available.
    (Just l', AtomicLocking f) -> return (l, f l' (Space "global") dests)
    -- Locking update with no lock array yet: create one.
    (Nothing, AtomicLocking f) -> do
      let num_locks :: Int
          num_locks = 100151
          dims = map (toExp' int32) $
                 shapeDims (histShape (slugOp slug)) ++
                 [ Var (slugNumSubhistos slug)
                 , histWidth (slugOp slug)]

      locks <-
        sStaticArray "hist_locks" (Space "device") int32 $
        Imp.ArrayZeros num_locks

      -- The lock for a given index is its flattened position modulo
      -- the number of locks, so distinct indices may share a lock.
      -- NOTE(review): the literals 0 1 0 are presumably the
      -- unlocked / to-lock / to-unlock values; confirm against the
      -- definition of 'Locking'.
      let l' = Locking locks 0 1 0
               (pure . (`rem` fromIntegral num_locks) . flattenIndex dims)

      return (Just l', f l' (Space "global") dests)
-- | Whether a histogram computation may be split into several passes.
-- 'bodyPassage' chooses 'MayBeMultiPass' only for kernel bodies that
-- consume nothing.  The derived 'Ord' instance orders
-- 'MustBeSinglePass' before 'MayBeMultiPass'.
data Passage = MustBeSinglePass | MayBeMultiPass deriving (Eq, Ord)
-- | Classify a kernel body: if alias analysis shows that it consumes
-- no names it 'MayBeMultiPass'; otherwise it 'MustBeSinglePass'.
bodyPassage :: KernelBody KernelsMem -> Passage
bodyPassage kbody
  | mempty == consumedInKernelBody (aliasAnalyseKernelBody kbody) =
      MayBeMultiPass
  | otherwise =
      MustBeSinglePass
prepareIntermediateArraysGlobal :: Passage -> Imp.Exp -> Imp.Exp -> [SegHistSlug]
-> CallKernelGen
(Imp.Exp,
[[Imp.Exp] -> InKernelGen ()])
prepareIntermediateArraysGlobal :: Passage
-> Exp
-> Exp
-> [SegHistSlug]
-> CallKernelGen (Exp, [[Exp] -> InKernelGen ()])
prepareIntermediateArraysGlobal Passage
passage Exp
hist_T Exp
hist_N [SegHistSlug]
slugs = do
Exp
hist_H <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_H" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> ([Exp] -> Exp) -> [Exp] -> ImpM KernelsMem HostEnv HostOp Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Exp] -> ImpM KernelsMem HostEnv HostOp Exp)
-> ImpM KernelsMem HostEnv HostOp [Exp]
-> ImpM KernelsMem HostEnv HostOp Exp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (SegHistSlug -> ImpM KernelsMem HostEnv HostOp Exp)
-> [SegHistSlug] -> ImpM KernelsMem HostEnv HostOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp (SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
-> (SegHistSlug -> SubExp)
-> SegHistSlug
-> ImpM KernelsMem HostEnv HostOp Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth (HistOp KernelsMem -> SubExp)
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs
Exp
hist_RF <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_RF" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
[Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((SegHistSlug -> Exp) -> [SegHistSlug] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 (Exp -> Exp) -> (SegHistSlug -> Exp) -> SegHistSlug -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp) -> (SegHistSlug -> SubExp) -> SegHistSlug -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histRaceFactor (HistOp KernelsMem -> SubExp)
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs)
Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/ Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 ([SegHistSlug] -> Exp
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs)
Exp
hist_el_size <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_el_size" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Exp) -> [SegHistSlug] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> Exp
slugElAvgSize [SegHistSlug]
slugs
Exp
hist_C_max <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_C_max" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (FloatType -> BinOp
FMin FloatType
Float64) (Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 Exp
hist_T) (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 Exp
hist_H Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/ Exp
hist_k_ct_min
Exp
hist_M_min <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_M_min" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMax IntType
Int32) Exp
1 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ Exp -> Exp
forall v. PrimExp v -> PrimExp v
t64 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 Exp
hist_T Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/ Exp
hist_C_max
let hist_L2_def :: Int32
hist_L2_def = Int32
4 Int32 -> Int32 -> Int32
forall a. Num a => a -> a -> a
* Int32
1024 Int32 -> Int32 -> Int32
forall a. Num a => a -> a -> a
* Int32
1024
VName
hist_L2 <- String -> PrimType -> ImpM KernelsMem HostEnv HostOp VName
forall lore r op. String -> PrimType -> ImpM lore r op VName
dPrim String
"L2_size" PrimType
int32
Maybe Name
entry <- ImpM KernelsMem HostEnv HostOp (Maybe Name)
forall lore r op. ImpM lore r op (Maybe Name)
askFunction
HostOp -> CallKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> Name -> SizeClass -> HostOp
Imp.GetSize VName
hist_L2
(Maybe Name -> Name -> Name
keyWithEntryPoint Maybe Name
entry (Name -> Name) -> Name -> Name
forall a b. (a -> b) -> a -> b
$ String -> Name
nameFromString (VName -> String
forall a. Pretty a => a -> String
pretty VName
hist_L2)) (SizeClass -> HostOp) -> SizeClass -> HostOp
forall a b. (a -> b) -> a -> b
$
Name -> Int32 -> SizeClass
Imp.SizeBespoke (String -> Name
nameFromString String
"L2_for_histogram") Int32
hist_L2_def
let hist_L2_ln_sz :: Exp
hist_L2_ln_sz = Exp
16Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
*Exp
4
Exp
hist_RACE_exp <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_RACE_exp" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (FloatType -> BinOp
FMax FloatType
Float64) Exp
1 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
(Exp
hist_k_RF Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_RF) Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/
(Exp
hist_L2_ln_sz Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/ Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 Exp
hist_el_size)
VName
hist_S <- String -> PrimType -> ImpM KernelsMem HostEnv HostOp VName
forall lore r op. String -> PrimType -> ImpM lore r op VName
dPrim String
"hist_S" PrimType
int32
Exp -> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall lore r op.
Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (Exp
hist_N Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
hist_H)
(VName
hist_S VName -> Exp -> CallKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<-- Exp
1) (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
hist_S VName -> Exp -> CallKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<--
case Passage
passage of
Passage
MayBeMultiPass ->
(Exp
hist_M_min Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_H Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_el_size) Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`divUp`
Exp -> Exp
forall v. PrimExp v -> PrimExp v
t64 (Exp
hist_F_L2 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 (VName -> Exp
Imp.vi32 VName
hist_L2) Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_RACE_exp)
Passage
MustBeSinglePass ->
Exp
1
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Race expansion factor (RACE^exp)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_RACE_exp
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of chunks (S)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ VName -> Exp
Imp.vi32 VName
hist_S
[[Exp] -> InKernelGen ()]
histograms <- (Maybe Locking, [[Exp] -> InKernelGen ()])
-> [[Exp] -> InKernelGen ()]
forall a b. (a, b) -> b
snd ((Maybe Locking, [[Exp] -> InKernelGen ()])
-> [[Exp] -> InKernelGen ()])
-> ImpM
KernelsMem
HostEnv
HostOp
(Maybe Locking, [[Exp] -> InKernelGen ()])
-> ImpM KernelsMem HostEnv HostOp [[Exp] -> InKernelGen ()]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Maybe Locking
-> SegHistSlug
-> CallKernelGen (Maybe Locking, [Exp] -> InKernelGen ()))
-> Maybe Locking
-> [SegHistSlug]
-> ImpM
KernelsMem
HostEnv
HostOp
(Maybe Locking, [[Exp] -> InKernelGen ()])
forall (m :: * -> *) acc x y.
Monad m =>
(acc -> x -> m (acc, y)) -> acc -> [x] -> m (acc, [y])
mapAccumLM (Exp
-> Exp
-> Exp
-> Exp
-> Maybe Locking
-> SegHistSlug
-> CallKernelGen (Maybe Locking, [Exp] -> InKernelGen ())
onOp (VName -> Exp
Imp.vi32 VName
hist_L2) Exp
hist_M_min (VName -> Exp
Imp.vi32 VName
hist_S) Exp
hist_RACE_exp) Maybe Locking
forall a. Maybe a
Nothing [SegHistSlug]
slugs
(Exp, [[Exp] -> InKernelGen ()])
-> CallKernelGen (Exp, [[Exp] -> InKernelGen ()])
forall (m :: * -> *) a. Monad m => a -> m a
return (VName -> Exp
Imp.vi32 VName
hist_S, [[Exp] -> InKernelGen ()]
histograms)
where
hist_k_ct_min :: Exp
hist_k_ct_min = Exp
2
hist_k_RF :: Exp
hist_k_RF = Exp
0.75
hist_F_L2 :: Exp
hist_F_L2 = Exp
0.4
r64 :: PrimExp v -> PrimExp v
r64 = ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> FloatType -> ConvOp
SIToFP IntType
Int32 FloatType
Float64)
t64 :: PrimExp v -> PrimExp v
t64 = ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (FloatType -> IntType -> ConvOp
FPToSI FloatType
Float64 IntType
Int32)
slugElAvgSize :: SegHistSlug -> Exp
slugElAvgSize slug :: SegHistSlug
slug@(SegHistSlug HistOp KernelsMem
op VName
_ [SubhistosInfo]
_ AtomicUpdate KernelsMem KernelEnv
do_op) =
case AtomicUpdate KernelsMem KernelEnv
do_op of
AtomicLocking{} ->
SegHistSlug -> Exp
slugElSize SegHistSlug
slug Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` (Exp
1Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+[Type] -> Exp
forall i a. Num i => [a] -> i
genericLength (LambdaT KernelsMem -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType (HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp HistOp KernelsMem
op)))
AtomicUpdate KernelsMem KernelEnv
_ ->
SegHistSlug -> Exp
slugElSize SegHistSlug
slug Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` [Type] -> Exp
forall i a. Num i => [a] -> i
genericLength (LambdaT KernelsMem -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType (HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp HistOp KernelsMem
op))
slugElSize :: SegHistSlug -> Exp
slugElSize (SegHistSlug HistOp KernelsMem
op VName
_ [SubhistosInfo]
_ AtomicUpdate KernelsMem KernelEnv
do_op) =
case AtomicUpdate KernelsMem KernelEnv
do_op of
AtomicLocking{} ->
IntType -> Exp -> Exp
forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
Int32 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ Count Bytes Exp -> Exp
forall u e. Count u e -> e
unCount (Count Bytes Exp -> Exp) -> Count Bytes Exp -> Exp
forall a b. (a -> b) -> a -> b
$
[Count Bytes Exp] -> Count Bytes Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Count Bytes Exp] -> Count Bytes Exp)
-> [Count Bytes Exp] -> Count Bytes Exp
forall a b. (a -> b) -> a -> b
$ (Type -> Count Bytes Exp) -> [Type] -> [Count Bytes Exp]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> Count Bytes Exp
typeSize (Type -> Count Bytes Exp)
-> (Type -> Type) -> Type -> Count Bytes Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Shape -> Type
`arrayOfShape` HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp KernelsMem
op)) ([Type] -> [Count Bytes Exp]) -> [Type] -> [Count Bytes Exp]
forall a b. (a -> b) -> a -> b
$
PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int32 Type -> [Type] -> [Type]
forall a. a -> [a] -> [a]
: LambdaT KernelsMem -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType (HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp HistOp KernelsMem
op)
AtomicUpdate KernelsMem KernelEnv
_ ->
IntType -> Exp -> Exp
forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
Int32 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ Count Bytes Exp -> Exp
forall u e. Count u e -> e
unCount (Count Bytes Exp -> Exp) -> Count Bytes Exp -> Exp
forall a b. (a -> b) -> a -> b
$ [Count Bytes Exp] -> Count Bytes Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Count Bytes Exp] -> Count Bytes Exp)
-> [Count Bytes Exp] -> Count Bytes Exp
forall a b. (a -> b) -> a -> b
$
(Type -> Count Bytes Exp) -> [Type] -> [Count Bytes Exp]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> Count Bytes Exp
typeSize (Type -> Count Bytes Exp)
-> (Type -> Type) -> Type -> Count Bytes Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Shape -> Type
`arrayOfShape` HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp KernelsMem
op)) ([Type] -> [Count Bytes Exp]) -> [Type] -> [Count Bytes Exp]
forall a b. (a -> b) -> a -> b
$
LambdaT KernelsMem -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType (HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp HistOp KernelsMem
op)
-- | Host-side set-up for one histogram operator when subhistograms are
-- kept in global memory.
--
-- NOTE(review): this is a local helper of an enclosing definition that
-- begins before this excerpt.  It uses @hist_N@, @hist_T@, @hist_F_L2@,
-- @r64@, @t64@ and @slugElSize@ from that scope, so it has to be
-- indented into the enclosing @where@ clause.
--
-- Arguments, in order: @hist_L2@, @hist_M_min@, @hist_S@ (the divisor
-- used to derive the chunk size from the histogram width),
-- @hist_RACE_exp@, the locking state threaded from the previous
-- operator, and the slug being processed.
--
-- Returns whatever 'prepareAtomicUpdateGlobal' yields: the updated
-- locking state and the in-kernel action that performs one update
-- given the destination indices.
onOp :: Imp.Exp
     -> Imp.Exp
     -> Imp.Exp
     -> Imp.Exp
     -> Maybe Locking
     -> SegHistSlug
     -> CallKernelGen (Maybe Locking, [Imp.Exp] -> InKernelGen ())
onOp hist_L2 hist_M_min hist_S hist_RACE_exp l slug = do
  let SegHistSlug op num_subhistos subhisto_info do_op = slug

  hist_H <- toExp $ histWidth op

  -- Chunk size: histogram width divided (rounding up) by hist_S.
  hist_H_chk <- dPrimVE "hist_H_chk" $
    hist_H `divUp` hist_S

  emit $ Imp.DebugPrint "Chunk size (H_chk)" $ Just hist_H_chk

  -- Elements per thread in L2 cache.  Function application binds
  -- tighter than (/), so the division by hist_T applies to the whole
  -- FMin expression.
  hist_k_max <- dPrimVE "hist_k_max" $
    Imp.BinOpExp (FMin Float64)
      (hist_F_L2 * (r64 hist_L2 / r64 (slugElSize slug)) * hist_RACE_exp)
      (r64 hist_N)
    / r64 hist_T

  hist_u <- dPrimVE "hist_u" $
    case do_op of
      AtomicPrim{} -> 2
      _            -> 1

  -- Cooperation level.
  hist_C <- dPrimVE "hist_C" $
    Imp.BinOpExp (FMin Float64) (r64 hist_T) $
    r64 (hist_u * hist_H_chk) / hist_k_max

  -- Multiplication degree: always 1 for primitive atomics, otherwise
  -- at least hist_M_min.
  hist_M <- dPrimVE "hist_M" $
    case slugAtomicUpdate slug of
      AtomicPrim{} -> 1
      _            -> Imp.BinOpExp (SMax Int32) hist_M_min $
                      t64 $ r64 hist_T / hist_C

  emit $ Imp.DebugPrint "Elements/thread in L2 cache (k_max)" $ Just hist_k_max
  emit $ Imp.DebugPrint "Multiplication degree (M)" $ Just hist_M
  emit $ Imp.DebugPrint "Cooperation level (C)" $ Just hist_C

  -- The slug's subhistogram-count variable is set to hist_M here.
  num_subhistos <-- hist_M

  dests <- forM (zip (histDest op) subhisto_info) $ \(dest, info) -> do
    dest_mem <- entryArrayLocation <$> lookupArray dest

    sub_mem <- memLocationName . entryArrayLocation
               <$> lookupArray (subhistosArray info)

    let -- With a single subhistogram, point the subhistogram memory
        -- at the destination's memory instead of allocating.
        unitHistoCase =
          emit $
          Imp.SetMem sub_mem (memLocationName dest_mem) $
          Space "device"

        -- Otherwise run the allocation action stored in the slug.
        multiHistoCase = subhistosAlloc info

    sIf (hist_M .==. 1) unitHistoCase multiHistoCase

    return $ subhistosArray info

  prepareAtomicUpdateGlobal l dests slug
-- | Generate the @seghist_global@ kernel for a single chunk pass.
--
-- Arguments, in order: the pattern elements for the map-out results,
-- the number of groups and the group size to launch with, the segment
-- space, the histogram slugs, the kernel body, one update action per
-- slug, @hist_S@ (the divisor used to derive each histogram's chunk
-- size from its width) and @chk_i@ (the index of the chunk handled by
-- this pass).
--
-- Only buckets that fall inside chunk @chk_i@, and below the
-- destination width, are updated.
histKernelGlobalPass :: [PatElem KernelsMem]
                     -> Count NumGroups Imp.Exp
                     -> Count GroupSize Imp.Exp
                     -> SegSpace
                     -> [SegHistSlug]
                     -> KernelBody KernelsMem
                     -> [[Imp.Exp] -> InKernelGen ()]
                     -> Imp.Exp -> Imp.Exp
                     -> CallKernelGen ()
histKernelGlobalPass map_pes num_groups group_size space slugs kbody histograms hist_S chk_i = do
  let (space_is, space_sizes) = unzip $ unSegSpace space
      space_sizes_64 = map (sExt Int64 . toExp' int32) space_sizes
      total_w_64 = product space_sizes_64

  -- Per-histogram chunk size, computed on the host before the launch.
  hist_H_chks <- forM (map (histWidth . slugOp) slugs) $ \w -> do
    w' <- toExp w
    dPrimVE "hist_H_chk" $ w' `divUp` hist_S

  sKernelThread "seghist_global" num_groups group_size (segFlat space) $ do
    constants <- kernelConstants <$> askEnv

    -- Subhistogram index of this thread, one per histogram.
    subhisto_inds <- forM slugs $ \slug ->
      dPrimVE "subhisto_ind" $
      kernelGlobalThreadId constants `quot`
      (kernelNumThreads constants `divUp` Imp.vi32 (slugNumSubhistos slug))

    -- The flat offset is handled as a 64-bit value (presumably to
    -- avoid overflow in the product of the space sizes); the
    -- unflattened indices are narrowed back to 32 bits below.
    let gtid = sExt Int64 $ kernelGlobalThreadId constants
        num_threads = sExt Int64 $ kernelNumThreads constants

    kernelLoop gtid num_threads total_w_64 $ \offset -> do
      -- Declare and assign the segment-space index variables.
      let setIndex v e = do
            dPrim_ v int32
            v <-- e
      zipWithM_ setIndex space_is $
        map (sExt Int32) $ unflattenIndex space_sizes_64 offset

      let input_in_bounds = offset .<. total_w_64

      sWhen input_in_bounds $ compileStms mempty (kernelBodyStms kbody) $ do
        -- The last @length map_pes@ results are the map-out results;
        -- everything before them belongs to the histograms.
        let (red_res, map_res) =
              splitFromEnd (length map_pes) $ kernelBodyResult kbody

        sComment "save map-out results" $
          forM_ (zip map_pes map_res) $ \(pe, res) ->
            copyDWIMFix (patElemName pe)
              (map Imp.vi32 space_is)
              (kernelResultSubExp res) []

        -- One bucket per slug comes first, followed by the values,
        -- which are split by each operator's destination count.
        let (buckets, vs) = splitAt (length slugs) red_res
            perOp = chunks $ map (length . histDest . slugOp) slugs

        sComment "perform atomic updates" $
          forM_ (zip6 (map slugOp slugs) histograms buckets (perOp vs) subhisto_inds hist_H_chks) $
            \(HistOp dest_w _ _ _ shape lam,
              do_op, bucket, vs', subhisto_ind, hist_H_chk) -> do

              let chk_beg = chk_i * hist_H_chk
                  bucket' = toExp' int32 $ kernelResultSubExp bucket
                  dest_w' = toExp' int32 dest_w
                  bucket_in_bounds = chk_beg .<=. bucket' .&&.
                                     bucket' .<. (chk_beg + hist_H_chk) .&&.
                                     bucket' .<. dest_w'
                  vs_params = takeLast (length vs') $ lambdaParams lam

              sWhen bucket_in_bounds $ do
                -- NOTE(review): 'init' assumes the segment space has
                -- at least one dimension; confirm against callers.
                let bucket_is = map Imp.vi32 (init space_is) ++
                                [subhisto_ind, bucket']
                dLParams $ lambdaParams lam
                sLoopNest shape $ \is -> do
                  forM_ (zip vs_params vs') $ \(p, res) ->
                    copyDWIMFix (paramName p) [] (kernelResultSubExp res) is
                  do_op (bucket_is ++ is)
-- | Compile a segmented histogram using subhistograms in global
-- memory.
--
-- Converts the group count and group size to expressions, calls
-- 'prepareIntermediateArraysGlobal' to obtain the number of chunks
-- (@hist_S@) and one update action per slug, and then runs
-- 'histKernelGlobalPass' once for every chunk index below @hist_S@.
histKernelGlobal :: [PatElem KernelsMem]
                 -> Count NumGroups SubExp -> Count GroupSize SubExp
                 -> SegSpace
                 -> [SegHistSlug]
                 -> KernelBody KernelsMem
                 -> CallKernelGen ()
histKernelGlobal map_pes num_groups group_size space slugs kbody = do
  num_groups' <- traverse toExp num_groups
  group_size' <- traverse toExp group_size

  let space_sizes = map snd $ unSegSpace space
      num_threads = unCount num_groups' * unCount group_size'

  emit $ Imp.DebugPrint "## Using global memory" Nothing

  -- NOTE(review): 'last' assumes the segment space has at least one
  -- dimension; confirm against callers.
  (hist_S, histograms) <-
    prepareIntermediateArraysGlobal (bodyPassage kbody)
      num_threads (toExp' int32 $ last space_sizes) slugs

  sFor "chk_i" hist_S $ \chk_i ->
    histKernelGlobalPass map_pes num_groups' group_size' space slugs kbody
      histograms hist_S chk_i
-- | Result type of 'prepareIntermediateArraysLocal': one entry per
-- histogram slug.  Each entry pairs a list of names with an in-kernel
-- action that takes a 'SubExp' and yields a further list of names and
-- an update action over index expressions.
--
-- NOTE(review): the roles of the two name lists and of the 'SubExp'
-- argument are not visible in this excerpt; confirm against the
-- producer before relying on them.
type InitLocalHistograms = [([VName],
                             SubExp ->
                             InKernelGen ([VName],
                                          [Imp.Exp] -> InKernelGen ()))]
prepareIntermediateArraysLocal :: VName
-> Count NumGroups Imp.Exp
-> SegSpace -> [SegHistSlug]
-> CallKernelGen InitLocalHistograms
prepareIntermediateArraysLocal :: VName
-> Count NumGroups Exp
-> SegSpace
-> [SegHistSlug]
-> CallKernelGen InitLocalHistograms
prepareIntermediateArraysLocal VName
num_subhistos_per_group Count NumGroups Exp
groups_per_segment SegSpace
space [SegHistSlug]
slugs = do
Exp
num_segments <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"num_segments" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
[Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ ((VName, SubExp) -> Exp) -> [(VName, SubExp)] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp)
-> ((VName, SubExp) -> SubExp) -> (VName, SubExp) -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) ([(VName, SubExp)] -> [Exp]) -> [(VName, SubExp)] -> [Exp]
forall a b. (a -> b) -> a -> b
$ [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
init ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
(SegHistSlug
-> ImpM
KernelsMem
HostEnv
HostOp
([VName],
SubExp
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())))
-> [SegHistSlug] -> CallKernelGen InitLocalHistograms
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Exp
-> SegHistSlug
-> ImpM
KernelsMem
HostEnv
HostOp
([VName],
SubExp
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ()))
onOp Exp
num_segments) [SegHistSlug]
slugs
where
onOp :: Exp
-> SegHistSlug
-> ImpM
KernelsMem
HostEnv
HostOp
([VName],
SubExp
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ()))
onOp Exp
num_segments (SegHistSlug HistOp KernelsMem
op VName
num_subhistos [SubhistosInfo]
subhisto_info AtomicUpdate KernelsMem KernelEnv
do_op) = do
VName
num_subhistos VName -> Exp -> CallKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<-- Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
num_segments
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of subhistograms in global memory" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ VName -> Exp
Imp.vi32 VName
num_subhistos
SubExp
-> ImpM
KernelsMem KernelEnv KernelOp (DoAtomicUpdate KernelsMem KernelEnv)
mk_op <-
case AtomicUpdate KernelsMem KernelEnv
do_op of
AtomicPrim DoAtomicUpdate KernelsMem KernelEnv
f -> (SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
-> ImpM
KernelsMem
HostEnv
HostOp
(SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
forall (m :: * -> *) a. Monad m => a -> m a
return ((SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
-> ImpM
KernelsMem
HostEnv
HostOp
(SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv)))
-> (SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
-> ImpM
KernelsMem
HostEnv
HostOp
(SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
forall a b. (a -> b) -> a -> b
$ ImpM
KernelsMem KernelEnv KernelOp (DoAtomicUpdate KernelsMem KernelEnv)
-> SubExp
-> ImpM
KernelsMem KernelEnv KernelOp (DoAtomicUpdate KernelsMem KernelEnv)
forall a b. a -> b -> a
const (ImpM
KernelsMem KernelEnv KernelOp (DoAtomicUpdate KernelsMem KernelEnv)
-> SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
-> ImpM
KernelsMem KernelEnv KernelOp (DoAtomicUpdate KernelsMem KernelEnv)
-> SubExp
-> ImpM
KernelsMem KernelEnv KernelOp (DoAtomicUpdate KernelsMem KernelEnv)
forall a b. (a -> b) -> a -> b
$ DoAtomicUpdate KernelsMem KernelEnv
-> ImpM
KernelsMem KernelEnv KernelOp (DoAtomicUpdate KernelsMem KernelEnv)
forall (m :: * -> *) a. Monad m => a -> m a
return DoAtomicUpdate KernelsMem KernelEnv
f
AtomicCAS DoAtomicUpdate KernelsMem KernelEnv
f -> (SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
-> ImpM
KernelsMem
HostEnv
HostOp
(SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
forall (m :: * -> *) a. Monad m => a -> m a
return ((SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
-> ImpM
KernelsMem
HostEnv
HostOp
(SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv)))
-> (SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
-> ImpM
KernelsMem
HostEnv
HostOp
(SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
forall a b. (a -> b) -> a -> b
$ ImpM
KernelsMem KernelEnv KernelOp (DoAtomicUpdate KernelsMem KernelEnv)
-> SubExp
-> ImpM
KernelsMem KernelEnv KernelOp (DoAtomicUpdate KernelsMem KernelEnv)
forall a b. a -> b -> a
const (ImpM
KernelsMem KernelEnv KernelOp (DoAtomicUpdate KernelsMem KernelEnv)
-> SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
-> ImpM
KernelsMem KernelEnv KernelOp (DoAtomicUpdate KernelsMem KernelEnv)
-> SubExp
-> ImpM
KernelsMem KernelEnv KernelOp (DoAtomicUpdate KernelsMem KernelEnv)
forall a b. (a -> b) -> a -> b
$ DoAtomicUpdate KernelsMem KernelEnv
-> ImpM
KernelsMem KernelEnv KernelOp (DoAtomicUpdate KernelsMem KernelEnv)
forall (m :: * -> *) a. Monad m => a -> m a
return DoAtomicUpdate KernelsMem KernelEnv
f
AtomicLocking Locking -> DoAtomicUpdate KernelsMem KernelEnv
f -> (SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
-> ImpM
KernelsMem
HostEnv
HostOp
(SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
forall (m :: * -> *) a. Monad m => a -> m a
return ((SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
-> ImpM
KernelsMem
HostEnv
HostOp
(SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv)))
-> (SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
-> ImpM
KernelsMem
HostEnv
HostOp
(SubExp
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
forall a b. (a -> b) -> a -> b
$ \SubExp
hist_H_chk -> do
let lock_shape :: Shape
lock_shape =
[SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
num_subhistos_per_group SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
:
Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp KernelsMem
op) [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++
[SubExp
hist_H_chk]
[Exp]
dims <- (SubExp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> [SubExp] -> ImpM KernelsMem KernelEnv KernelOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp ([SubExp] -> ImpM KernelsMem KernelEnv KernelOp [Exp])
-> [SubExp] -> ImpM KernelsMem KernelEnv KernelOp [Exp]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
lock_shape
VName
locks <- String
-> PrimType
-> Shape
-> Space
-> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op.
String -> PrimType -> Shape -> Space -> ImpM lore r op VName
sAllocArray String
"locks" PrimType
int32 Shape
lock_shape (Space -> ImpM KernelsMem KernelEnv KernelOp VName)
-> Space -> ImpM KernelsMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ String -> Space
Space String
"local"
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"All locks start out unlocked" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[Exp] -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
groupCoverSpace [Exp]
dims (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
is ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
locks [Exp]
is (IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
0) []
DoAtomicUpdate KernelsMem KernelEnv
-> ImpM
KernelsMem KernelEnv KernelOp (DoAtomicUpdate KernelsMem KernelEnv)
forall (m :: * -> *) a. Monad m => a -> m a
return (DoAtomicUpdate KernelsMem KernelEnv
-> ImpM
KernelsMem
KernelEnv
KernelOp
(DoAtomicUpdate KernelsMem KernelEnv))
-> DoAtomicUpdate KernelsMem KernelEnv
-> ImpM
KernelsMem KernelEnv KernelOp (DoAtomicUpdate KernelsMem KernelEnv)
forall a b. (a -> b) -> a -> b
$ Locking -> DoAtomicUpdate KernelsMem KernelEnv
f (Locking -> DoAtomicUpdate KernelsMem KernelEnv)
-> Locking -> DoAtomicUpdate KernelsMem KernelEnv
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Exp -> Exp -> ([Exp] -> [Exp]) -> Locking
Locking VName
locks Exp
0 Exp
1 Exp
0 [Exp] -> [Exp]
forall a. a -> a
id
let init_local_subhistos :: SubExp
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())
init_local_subhistos SubExp
hist_H_chk = do
[VName]
local_subhistos <-
[Type]
-> (Type -> ImpM KernelsMem KernelEnv KernelOp VName)
-> ImpM KernelsMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (HistOp KernelsMem -> [Type]
forall lore. HistOp lore -> [Type]
histType HistOp KernelsMem
op) ((Type -> ImpM KernelsMem KernelEnv KernelOp VName)
-> ImpM KernelsMem KernelEnv KernelOp [VName])
-> (Type -> ImpM KernelsMem KernelEnv KernelOp VName)
-> ImpM KernelsMem KernelEnv KernelOp [VName]
forall a b. (a -> b) -> a -> b
$ \Type
t -> do
let sub_local_shape :: Shape
sub_local_shape =
[SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [VName -> SubExp
Var VName
num_subhistos_per_group] Shape -> Shape -> Shape
forall a. Semigroup a => a -> a -> a
<>
(Type -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t Shape -> SubExp -> Shape
forall d. ShapeBase d -> d -> ShapeBase d
`setOuterDim` SubExp
hist_H_chk)
String
-> PrimType
-> Shape
-> Space
-> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op.
String -> PrimType -> Shape -> Space -> ImpM lore r op VName
sAllocArray String
"subhistogram_local"
(Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) Shape
sub_local_shape (String -> Space
Space String
"local")
DoAtomicUpdate KernelsMem KernelEnv
do_op' <- SubExp
-> ImpM
KernelsMem KernelEnv KernelOp (DoAtomicUpdate KernelsMem KernelEnv)
mk_op SubExp
hist_H_chk
([VName], [Exp] -> InKernelGen ())
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return ([VName]
local_subhistos, DoAtomicUpdate KernelsMem KernelEnv
do_op' (String -> Space
Space String
"local") [VName]
local_subhistos)
[VName]
glob_subhistos <- [SubhistosInfo]
-> (SubhistosInfo -> ImpM KernelsMem HostEnv HostOp VName)
-> ImpM KernelsMem HostEnv HostOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [SubhistosInfo]
subhisto_info ((SubhistosInfo -> ImpM KernelsMem HostEnv HostOp VName)
-> ImpM KernelsMem HostEnv HostOp [VName])
-> (SubhistosInfo -> ImpM KernelsMem HostEnv HostOp VName)
-> ImpM KernelsMem HostEnv HostOp [VName]
forall a b. (a -> b) -> a -> b
$ \SubhistosInfo
info -> do
SubhistosInfo -> CallKernelGen ()
subhistosAlloc SubhistosInfo
info
VName -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *) a. Monad m => a -> m a
return (VName -> ImpM KernelsMem HostEnv HostOp VName)
-> VName -> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ SubhistosInfo -> VName
subhistosArray SubhistosInfo
info
([VName],
SubExp
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ()))
-> ImpM
KernelsMem
HostEnv
HostOp
([VName],
SubExp
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ()))
forall (m :: * -> *) a. Monad m => a -> m a
return ([VName]
glob_subhistos, SubExp
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())
init_local_subhistos)
histKernelLocalPass :: VName -> Count NumGroups Imp.Exp
-> [PatElem KernelsMem]
-> Count NumGroups Imp.Exp -> Count GroupSize Imp.Exp
-> SegSpace
-> [SegHistSlug]
-> KernelBody KernelsMem
-> InitLocalHistograms -> Imp.Exp -> Imp.Exp
-> CallKernelGen ()
histKernelLocalPass :: VName
-> Count NumGroups Exp
-> [PatElem KernelsMem]
-> Count NumGroups Exp
-> Count GroupSize Exp
-> SegSpace
-> [SegHistSlug]
-> KernelBody KernelsMem
-> InitLocalHistograms
-> Exp
-> Exp
-> CallKernelGen ()
histKernelLocalPass VName
num_subhistos_per_group_var Count NumGroups Exp
groups_per_segment [PatElem KernelsMem]
map_pes Count NumGroups Exp
num_groups Count GroupSize Exp
group_size SegSpace
space [SegHistSlug]
slugs KernelBody KernelsMem
kbody
InitLocalHistograms
init_histograms Exp
hist_S Exp
chk_i = do
let ([VName]
space_is, [SubExp]
space_sizes) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
segment_is :: [VName]
segment_is = [VName] -> [VName]
forall a. [a] -> [a]
init [VName]
space_is
segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. [a] -> [a]
init [SubExp]
space_sizes
(VName
i_in_segment, SubExp
segment_size) = [(VName, SubExp)] -> (VName, SubExp)
forall a. [a] -> a
last ([(VName, SubExp)] -> (VName, SubExp))
-> [(VName, SubExp)] -> (VName, SubExp)
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
num_subhistos_per_group :: Exp
num_subhistos_per_group = VName -> PrimType -> Exp
Imp.var VName
num_subhistos_per_group_var PrimType
int32
Exp
segment_size' <- SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
segment_size
Exp
num_segments <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"num_segments" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
[Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ (SubExp -> Exp) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) [SubExp]
segment_dims
[VName]
hist_H_chks <- [SubExp]
-> (SubExp -> ImpM KernelsMem HostEnv HostOp VName)
-> ImpM KernelsMem HostEnv HostOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ((SegHistSlug -> SubExp) -> [SegHistSlug] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth (HistOp KernelsMem -> SubExp)
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs) ((SubExp -> ImpM KernelsMem HostEnv HostOp VName)
-> ImpM KernelsMem HostEnv HostOp [VName])
-> (SubExp -> ImpM KernelsMem HostEnv HostOp VName)
-> ImpM KernelsMem HostEnv HostOp [VName]
forall a b. (a -> b) -> a -> b
$ \SubExp
w -> do
Exp
w' <- SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
w
String -> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"hist_H_chk" (Exp -> ImpM KernelsMem HostEnv HostOp VName)
-> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ Exp
w' Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`divUp` Exp
hist_S
String
-> Count NumGroups Exp
-> Count GroupSize Exp
-> VName
-> InKernelGen ()
-> CallKernelGen ()
sKernelThread String
"seghist_local" Count NumGroups Exp
num_groups Count GroupSize Exp
group_size (SegSpace -> VName
segFlat SegSpace
space) (InKernelGen () -> CallKernelGen ())
-> InKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
SegVirt -> Exp -> (VName -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt (Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
num_segments) ((VName -> InKernelGen ()) -> InKernelGen ())
-> (VName -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
group_id_var -> do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
let group_id :: Exp
group_id = VName -> Exp
Imp.vi32 VName
group_id_var
Exp
flat_segment_id <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"flat_segment_id" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
group_id Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment
Exp
gid_in_segment <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"gid_in_segment" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
group_id Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`rem` Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment
Exp
pgtid_in_segment <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"pgtid_in_segment" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
Exp
gid_in_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupSize KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
Exp
threads_per_segment <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"threads_per_segment" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupSize KernelConstants
constants
(VName -> Exp -> InKernelGen ())
-> [VName] -> [Exp] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
dPrimV_ [VName]
segment_is ([Exp] -> InKernelGen ()) -> [Exp] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ((SubExp -> Exp) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) [SubExp]
segment_dims) Exp
flat_segment_id
[([(VName, VName)], VName, [Exp] -> InKernelGen ())]
histograms <- [(([VName],
SubExp
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())),
VName)]
-> ((([VName],
SubExp
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())),
VName)
-> ImpM
KernelsMem
KernelEnv
KernelOp
([(VName, VName)], VName, [Exp] -> InKernelGen ()))
-> ImpM
KernelsMem
KernelEnv
KernelOp
[([(VName, VName)], VName, [Exp] -> InKernelGen ())]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (InitLocalHistograms
-> [VName]
-> [(([VName],
SubExp
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())),
VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip InitLocalHistograms
init_histograms [VName]
hist_H_chks) (((([VName],
SubExp
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())),
VName)
-> ImpM
KernelsMem
KernelEnv
KernelOp
([(VName, VName)], VName, [Exp] -> InKernelGen ()))
-> ImpM
KernelsMem
KernelEnv
KernelOp
[([(VName, VName)], VName, [Exp] -> InKernelGen ())])
-> ((([VName],
SubExp
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())),
VName)
-> ImpM
KernelsMem
KernelEnv
KernelOp
([(VName, VName)], VName, [Exp] -> InKernelGen ()))
-> ImpM
KernelsMem
KernelEnv
KernelOp
[([(VName, VName)], VName, [Exp] -> InKernelGen ())]
forall a b. (a -> b) -> a -> b
$
\(([VName]
glob_subhistos, SubExp
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())
init_local_subhistos), VName
hist_H_chk) -> do
([VName]
local_subhistos, [Exp] -> InKernelGen ()
do_op) <- SubExp
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())
init_local_subhistos (SubExp
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ()))
-> SubExp
-> ImpM
KernelsMem KernelEnv KernelOp ([VName], [Exp] -> InKernelGen ())
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
hist_H_chk
([(VName, VName)], VName, [Exp] -> InKernelGen ())
-> ImpM
KernelsMem
KernelEnv
KernelOp
([(VName, VName)], VName, [Exp] -> InKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
glob_subhistos [VName]
local_subhistos, VName
hist_H_chk, [Exp] -> InKernelGen ()
do_op)
Exp
thread_local_subhisto_i <-
String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"thread_local_subhisto_i" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`rem` Exp
num_subhistos_per_group
let onSlugs :: (SegHistSlug
-> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ()
onSlugs SegHistSlug
-> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ()
f = [(SegHistSlug, ([(VName, VName)], VName, [Exp] -> InKernelGen ()))]
-> ((SegHistSlug,
([(VName, VName)], VName, [Exp] -> InKernelGen ()))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegHistSlug]
-> [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
-> [(SegHistSlug,
([(VName, VName)], VName, [Exp] -> InKernelGen ()))]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegHistSlug]
slugs [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
histograms) (((SegHistSlug, ([(VName, VName)], VName, [Exp] -> InKernelGen ()))
-> InKernelGen ())
-> InKernelGen ())
-> ((SegHistSlug,
([(VName, VName)], VName, [Exp] -> InKernelGen ()))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegHistSlug
slug, ([(VName, VName)]
dests, VName
hist_H_chk, [Exp] -> InKernelGen ()
_)) -> do
let histo_dims :: [Exp]
histo_dims = (SubExp -> Exp) -> [SubExp] -> [Exp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) ([SubExp] -> [Exp]) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
hist_H_chk SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
:
Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape (SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug))
Exp
histo_size <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"histo_size" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Exp]
histo_dims
SegHistSlug
-> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ()
f SegHistSlug
slug [(VName, VName)]
dests (VName -> Exp
Imp.vi32 VName
hist_H_chk) [Exp]
histo_dims Exp
histo_size
let onAllHistograms :: (VName
-> VName
-> HistOp KernelsMem
-> SubExp
-> Exp
-> Exp
-> [Exp]
-> [Exp]
-> InKernelGen ())
-> InKernelGen ()
onAllHistograms VName
-> VName
-> HistOp KernelsMem
-> SubExp
-> Exp
-> Exp
-> [Exp]
-> [Exp]
-> InKernelGen ()
f =
(SegHistSlug
-> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
-> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ())
-> (SegHistSlug
-> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests Exp
hist_H_chk [Exp]
histo_dims Exp
histo_size -> do
let group_hists_size :: Exp
group_hists_size = Exp
num_subhistos_per_group Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
histo_size
Exp
init_per_thread <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"init_per_thread" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
Exp
group_hists_size Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`divUp`
KernelConstants -> Exp
kernelGroupSize KernelConstants
constants
[((VName, VName), SubExp)]
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(VName, VName)] -> [SubExp] -> [((VName, VName), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(VName, VName)]
dests (HistOp KernelsMem -> [SubExp]
forall lore. HistOp lore -> [SubExp]
histNeutral (HistOp KernelsMem -> [SubExp]) -> HistOp KernelsMem -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug)) ((((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ())
-> (((VName, VName), SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
\((VName
dest_global, VName
dest_local), SubExp
ne) ->
String -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
String -> Exp -> (Exp -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"local_i" Exp
init_per_thread ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
i -> do
Exp
j <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"j" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
Exp
i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupSize KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+
KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
Exp
j_offset <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"j_offset" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
Exp
num_subhistos_per_group Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
histo_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
gid_in_segment Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
j
Exp
local_subhisto_i <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"local_subhisto_i" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
j Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` Exp
histo_size
let local_bucket_is :: [Exp]
local_bucket_is = [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [Exp]
histo_dims (Exp -> [Exp]) -> Exp -> [Exp]
forall a b. (a -> b) -> a -> b
$ Exp
j Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`rem` Exp
histo_size
global_bucket_is :: [Exp]
global_bucket_is = [Exp] -> Exp
forall a. [a] -> a
head [Exp]
local_bucket_is Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
chk_i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_H_chk Exp -> [Exp] -> [Exp]
forall a. a -> [a] -> [a]
:
[Exp] -> [Exp]
forall a. [a] -> [a]
tail [Exp]
local_bucket_is
Exp
global_subhisto_i <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"global_subhisto_i" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
j_offset Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` Exp
histo_size
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
j Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
group_hists_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
-> VName
-> HistOp KernelsMem
-> SubExp
-> Exp
-> Exp
-> [Exp]
-> [Exp]
-> InKernelGen ()
f VName
dest_local VName
dest_global (SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug) SubExp
ne
Exp
local_subhisto_i Exp
global_subhisto_i
[Exp]
local_bucket_is [Exp]
global_bucket_is
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"initialize histograms in local memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(VName
-> VName
-> HistOp KernelsMem
-> SubExp
-> Exp
-> Exp
-> [Exp]
-> [Exp]
-> InKernelGen ())
-> InKernelGen ()
onAllHistograms ((VName
-> VName
-> HistOp KernelsMem
-> SubExp
-> Exp
-> Exp
-> [Exp]
-> [Exp]
-> InKernelGen ())
-> InKernelGen ())
-> (VName
-> VName
-> HistOp KernelsMem
-> SubExp
-> Exp
-> Exp
-> [Exp]
-> [Exp]
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \VName
dest_local VName
dest_global HistOp KernelsMem
op SubExp
ne Exp
local_subhisto_i Exp
global_subhisto_i [Exp]
local_bucket_is [Exp]
global_bucket_is ->
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"First subhistogram is initialised from global memory; others with neutral element." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let global_is :: [Exp]
global_is = (VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Exp
Imp.vi32 [VName]
segment_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp
0] [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp]
global_bucket_is
local_is :: [Exp]
local_is = Exp
local_subhisto_i Exp -> [Exp] -> [Exp]
forall a. a -> [a] -> [a]
: [Exp]
local_bucket_is
Exp -> InKernelGen () -> InKernelGen () -> InKernelGen ()
forall lore r op.
Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (Exp
global_subhisto_i Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0)
(VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
dest_local [Exp]
local_is (VName -> SubExp
Var VName
dest_global) [Exp]
global_is)
(Shape -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([Exp] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest (HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp KernelsMem
op) (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
is ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
dest_local ([Exp]
local_is[Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++[Exp]
is) SubExp
ne [])
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
Exp -> Exp -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
kernelLoop Exp
pgtid_in_segment Exp
threads_per_segment Exp
segment_size' ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
ie -> do
VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
dPrimV_ VName
i_in_segment Exp
ie
Names -> Stms KernelsMem -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody KernelsMem -> Stms KernelsMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody KernelsMem
kbody) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let ([SubExp]
red_res, [SubExp]
map_res) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElemT LetDecMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem KernelsMem]
[PatElemT LetDecMem]
map_pes) ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$
(KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp ([KernelResult] -> [SubExp]) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ KernelBody KernelsMem -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody KernelsMem
kbody
([SubExp]
buckets, [SubExp]
vs) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegHistSlug] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SegHistSlug]
slugs) [SubExp]
red_res
perOp :: [SubExp] -> [[SubExp]]
perOp = [Int] -> [SubExp] -> [[SubExp]]
forall a. [Int] -> [a] -> [[a]]
chunks ([Int] -> [SubExp] -> [[SubExp]])
-> [Int] -> [SubExp] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Int) -> [SegHistSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int) -> (SegHistSlug -> [VName]) -> SegHistSlug -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> [VName]
forall lore. HistOp lore -> [VName]
histDest (HistOp KernelsMem -> [VName])
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
chk_i Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"save map-out results" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElemT LetDecMem, SubExp)]
-> ((PatElemT LetDecMem, SubExp) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem] -> [SubExp] -> [(PatElemT LetDecMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem KernelsMem]
[PatElemT LetDecMem]
map_pes [SubExp]
map_res) (((PatElemT LetDecMem, SubExp) -> InKernelGen ())
-> InKernelGen ())
-> ((PatElemT LetDecMem, SubExp) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
pe, SubExp
se) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
pe)
((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Exp
Imp.vi32 [VName]
space_is) SubExp
se []
[(HistOp KernelsMem,
([(VName, VName)], VName, [Exp] -> InKernelGen ()), SubExp,
[SubExp])]
-> ((HistOp KernelsMem,
([(VName, VName)], VName, [Exp] -> InKernelGen ()), SubExp,
[SubExp])
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp KernelsMem]
-> [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
-> [SubExp]
-> [[SubExp]]
-> [(HistOp KernelsMem,
([(VName, VName)], VName, [Exp] -> InKernelGen ()), SubExp,
[SubExp])]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 ((SegHistSlug -> HistOp KernelsMem)
-> [SegHistSlug] -> [HistOp KernelsMem]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> HistOp KernelsMem
slugOp [SegHistSlug]
slugs) [([(VName, VName)], VName, [Exp] -> InKernelGen ())]
histograms [SubExp]
buckets ([SubExp] -> [[SubExp]]
perOp [SubExp]
vs)) (((HistOp KernelsMem,
([(VName, VName)], VName, [Exp] -> InKernelGen ()), SubExp,
[SubExp])
-> InKernelGen ())
-> InKernelGen ())
-> ((HistOp KernelsMem,
([(VName, VName)], VName, [Exp] -> InKernelGen ()), SubExp,
[SubExp])
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
\(HistOp SubExp
dest_w SubExp
_ [VName]
_ [SubExp]
_ Shape
shape LambdaT KernelsMem
lam,
([(VName, VName)]
_, VName
hist_H_chk, [Exp] -> InKernelGen ()
do_op), SubExp
bucket, [SubExp]
vs') -> do
let chk_beg :: Exp
chk_beg = Exp
chk_i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* VName -> Exp
Imp.vi32 VName
hist_H_chk
bucket' :: Exp
bucket' = PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 SubExp
bucket
dest_w' :: Exp
dest_w' = PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 SubExp
dest_w
bucket_in_bounds :: Exp
bucket_in_bounds = Exp
bucket' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
dest_w' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&.
Exp
chk_beg Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<=. Exp
bucket' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&.
Exp
bucket' Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. (Exp
chk_beg Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ VName -> Exp
Imp.vi32 VName
hist_H_chk)
bucket_is :: [Exp]
bucket_is = [Exp
thread_local_subhisto_i, Exp
bucket' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
chk_beg]
vs_params :: [Param LetDecMem]
vs_params = Int -> [Param LetDecMem] -> [Param LetDecMem]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') ([Param LetDecMem] -> [Param LetDecMem])
-> [Param LetDecMem] -> [Param LetDecMem]
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT KernelsMem
lam
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"perform atomic updates" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
bucket_in_bounds (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
[LParam KernelsMem] -> InKernelGen ()
forall lore r op. Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams ([LParam KernelsMem] -> InKernelGen ())
-> [LParam KernelsMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT KernelsMem
lam
Shape -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
Shape -> ([Exp] -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest Shape
shape (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
is -> do
[(Param LetDecMem, SubExp)]
-> ((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [SubExp] -> [(Param LetDecMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
vs_params [SubExp]
vs') (((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, SubExp
v) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] SubExp
v [Exp]
is
[Exp] -> InKernelGen ()
do_op ([Exp]
bucket_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++ [Exp]
is)
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceGlobal
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Compact the multiple local memory subhistograms to result in global memory" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(SegHistSlug
-> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ()
onSlugs ((SegHistSlug
-> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ())
-> (SegHistSlug
-> [(VName, VName)] -> Exp -> [Exp] -> Exp -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegHistSlug
slug [(VName, VName)]
dests Exp
hist_H_chk [Exp]
histo_dims Exp
histo_size -> do
Exp
bins_per_thread <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"init_per_thread" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
Exp
histo_size Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`divUp` KernelConstants -> Exp
kernelGroupSize KernelConstants
constants
VName
trunc_H <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"trunc_H" (Exp -> ImpM KernelsMem KernelEnv KernelOp VName)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$
BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMin IntType
Int32) Exp
hist_H_chk (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth (SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug)) Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
-
Exp
chk_i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* [Exp] -> Exp
forall a. [a] -> a
head [Exp]
histo_dims
let trunc_histo_dims :: [Exp]
trunc_histo_dims = (SubExp -> Exp) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) ([SubExp] -> [Exp]) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
trunc_H SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
:
Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape (SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug))
Exp
trunc_histo_size <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"histo_size" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Exp]
trunc_histo_dims
String -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
String -> Exp -> (Exp -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"local_i" Exp
bins_per_thread ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
i -> do
Exp
j <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"j" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
Exp
i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupSize KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
j Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
trunc_histo_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let local_bucket_is :: [Exp]
local_bucket_is = [Exp] -> Exp -> [Exp]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [Exp]
histo_dims Exp
j
global_bucket_is :: [Exp]
global_bucket_is = [Exp] -> Exp
forall a. [a] -> a
head [Exp]
local_bucket_is Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
chk_i Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_H_chk Exp -> [Exp] -> [Exp]
forall a. a -> [a] -> [a]
:
[Exp] -> [Exp]
forall a. [a] -> [a]
tail [Exp]
local_bucket_is
[LParam KernelsMem] -> InKernelGen ()
forall lore r op. Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams ([LParam KernelsMem] -> InKernelGen ())
-> [LParam KernelsMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams (LambdaT KernelsMem -> [LParam KernelsMem])
-> LambdaT KernelsMem -> [LParam KernelsMem]
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp (HistOp KernelsMem -> LambdaT KernelsMem)
-> HistOp KernelsMem -> LambdaT KernelsMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug
let ([VName]
global_dests, [VName]
local_dests) = [(VName, VName)] -> ([VName], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, VName)]
dests
([Param LetDecMem]
xparams, [Param LetDecMem]
yparams) = Int -> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
local_dests) ([Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem]))
-> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a b. (a -> b) -> a -> b
$
LambdaT KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams (LambdaT KernelsMem -> [LParam KernelsMem])
-> LambdaT KernelsMem -> [LParam KernelsMem]
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp (HistOp KernelsMem -> LambdaT KernelsMem)
-> HistOp KernelsMem -> LambdaT KernelsMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Read values from subhistogram 0." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
xparams [VName]
local_dests) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
xp, VName
subhisto) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix
(Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
xp) []
(VName -> SubExp
Var VName
subhisto) (Exp
0Exp -> [Exp] -> [Exp]
forall a. a -> [a] -> [a]
:[Exp]
local_bucket_is)
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Accumulate based on values in other subhistograms." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
String -> Exp -> (Exp -> InKernelGen ()) -> InKernelGen ()
forall lore r op.
String -> Exp -> (Exp -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"subhisto_id" (Exp
num_subhistos_per_group Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1) ((Exp -> InKernelGen ()) -> InKernelGen ())
-> (Exp -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \Exp
subhisto_id -> do
[(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
yparams [VName]
local_dests) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
yp, VName
subhisto) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix
(Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
yp) []
(VName -> SubExp
Var VName
subhisto) (Exp
subhisto_id Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
1 Exp -> [Exp] -> [Exp]
forall a. a -> [a] -> [a]
: [Exp]
local_bucket_is)
[Param LetDecMem] -> Body KernelsMem -> InKernelGen ()
forall dec lore r op. [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' [Param LetDecMem]
xparams (Body KernelsMem -> InKernelGen ())
-> Body KernelsMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ LambdaT KernelsMem -> Body KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody (LambdaT KernelsMem -> Body KernelsMem)
-> LambdaT KernelsMem -> Body KernelsMem
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp (HistOp KernelsMem -> LambdaT KernelsMem)
-> HistOp KernelsMem -> LambdaT KernelsMem
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> HistOp KernelsMem
slugOp SegHistSlug
slug
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"Put final bucket value in global memory." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let global_is :: [Exp]
global_is =
(VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Exp
Imp.vi32 [VName]
segment_is [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++
[Exp
group_id Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`rem` Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment] [Exp] -> [Exp] -> [Exp]
forall a. [a] -> [a] -> [a]
++
[Exp]
global_bucket_is
[(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
xparams [VName]
global_dests) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
xp, VName
global_dest) ->
VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
global_dest [Exp]
global_is (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
xp) []
-- | Host-side driver for the local-memory histogram strategy.
--
-- In order, this:
--
--   1. converts the group count and group size from 'SubExp's to
--      expressions with 'toExp';
--   2. emits a debug print of the value of
--      @num_subhistos_per_group_var@ (read as an @int32@);
--   3. calls 'prepareIntermediateArraysLocal' once to obtain the
--      histogram initialisation value handed to every pass;
--   4. runs 'histKernelLocalPass' inside an 'sFor' loop named @chk_i@
--      whose bound is @hist_S@, passing the loop variable as the last
--      argument.
--
-- NOTE(review): @hist_S@ is presumably the number of chunks each
-- histogram is split into so that one chunk fits in local memory --
-- confirm against the caller.
histKernelLocal :: VName -> Count NumGroups Imp.Exp
                -> [PatElem KernelsMem]
                -> Count NumGroups SubExp -> Count GroupSize SubExp
                -> SegSpace
                -> Imp.Exp
                -> [SegHistSlug]
                -> KernelBody KernelsMem
                -> CallKernelGen ()
histKernelLocal num_subhistos_per_group_var groups_per_segment map_pes num_groups group_size space hist_S slugs kbody = do
  num_groups' <- traverse toExp num_groups
  group_size' <- traverse toExp group_size

  let num_subhistos_per_group = Imp.var num_subhistos_per_group_var int32

  emit $ Imp.DebugPrint "Number of local subhistograms per group" $
    Just num_subhistos_per_group

  -- Prepared once, outside the loop, and shared by every pass.
  init_histograms <-
    prepareIntermediateArraysLocal
      num_subhistos_per_group_var groups_per_segment space slugs

  -- One call to the pass generator per value of the chk_i loop variable.
  sFor "chk_i" hist_S $ \chk_i ->
    histKernelLocalPass
      num_subhistos_per_group_var groups_per_segment map_pes
      num_groups' group_size' space slugs kbody
      init_histograms hist_S chk_i
-- | A fixed integer limit chosen by the kind of atomic update the slug
-- uses: 3 for 'AtomicPrim', 4 for 'AtomicCAS' and 6 for
-- 'AtomicLocking'.
--
-- NOTE(review): judging by the name, this is the largest number of
-- local-memory passes considered acceptable for the slug -- confirm
-- against the call site.
slugMaxLocalMemPasses :: SegHistSlug -> Int
slugMaxLocalMemPasses slug =
  case slugAtomicUpdate slug of
    AtomicPrim _ -> 3
    AtomicCAS _ -> 4
    AtomicLocking _ -> 6
localMemoryCase :: [PatElem KernelsMem]
-> Imp.Exp
-> SegSpace
-> Imp.Exp -> Imp.Exp -> Imp.Exp -> Imp.Exp
-> [SegHistSlug]
-> KernelBody KernelsMem
-> CallKernelGen (Imp.Exp, CallKernelGen ())
localMemoryCase :: [PatElem KernelsMem]
-> Exp
-> SegSpace
-> Exp
-> Exp
-> Exp
-> Exp
-> [SegHistSlug]
-> KernelBody KernelsMem
-> CallKernelGen (Exp, CallKernelGen ())
localMemoryCase [PatElem KernelsMem]
map_pes Exp
hist_T SegSpace
space Exp
hist_H Exp
hist_el_size Exp
hist_N Exp
_ [SegHistSlug]
slugs KernelBody KernelsMem
kbody = do
let space_sizes :: [SubExp]
space_sizes = SegSpace -> [SubExp]
segSpaceDims SegSpace
space
segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. [a] -> [a]
init [SubExp]
space_sizes
segmented :: Bool
segmented = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [SubExp]
segment_dims
VName
hist_L <- String -> PrimType -> ImpM KernelsMem HostEnv HostOp VName
forall lore r op. String -> PrimType -> ImpM lore r op VName
dPrim String
"hist_L" PrimType
int32
HostOp -> CallKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax VName
hist_L SizeClass
Imp.SizeLocalMemory
VName
max_group_size <- String -> PrimType -> ImpM KernelsMem HostEnv HostOp VName
forall lore r op. String -> PrimType -> ImpM lore r op VName
dPrim String
"max_group_size" PrimType
int32
HostOp -> CallKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (HostOp -> CallKernelGen ()) -> HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> SizeClass -> HostOp
Imp.GetSizeMax VName
max_group_size SizeClass
Imp.SizeGroup
let group_size :: Count GroupSize SubExp
group_size = SubExp -> Count GroupSize SubExp
forall u e. e -> Count u e
Imp.Count (SubExp -> Count GroupSize SubExp)
-> SubExp -> Count GroupSize SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
max_group_size
Count NumGroups SubExp
num_groups <- (VName -> Count NumGroups SubExp)
-> ImpM KernelsMem HostEnv HostOp VName
-> ImpM KernelsMem HostEnv HostOp (Count NumGroups SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Imp.Count (SubExp -> Count NumGroups SubExp)
-> (VName -> SubExp) -> VName -> Count NumGroups SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) (ImpM KernelsMem HostEnv HostOp VName
-> ImpM KernelsMem HostEnv HostOp (Count NumGroups SubExp))
-> ImpM KernelsMem HostEnv HostOp VName
-> ImpM KernelsMem HostEnv HostOp (Count NumGroups SubExp)
forall a b. (a -> b) -> a -> b
$ String -> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"num_groups" (Exp -> ImpM KernelsMem HostEnv HostOp VName)
-> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
Exp
hist_T Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`divUp` PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount Count GroupSize SubExp
group_size)
let num_groups' :: Count NumGroups Exp
num_groups' = PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp) -> Count NumGroups SubExp -> Count NumGroups Exp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count NumGroups SubExp
num_groups
group_size' :: Count GroupSize Exp
group_size' = PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp) -> Count GroupSize SubExp -> Count GroupSize Exp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count GroupSize SubExp
group_size
let r64 :: PrimExp v -> PrimExp v
r64 = ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> FloatType -> ConvOp
SIToFP IntType
Int32 FloatType
Float64)
t64 :: PrimExp v -> PrimExp v
t64 = ConvOp -> PrimExp v -> PrimExp v
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (FloatType -> IntType -> ConvOp
FPToSI FloatType
Float64 IntType
Int32)
Exp
hist_m' <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_m_prime" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 (BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMin IntType
Int32)
(VName -> Exp
Imp.vi32 VName
hist_L Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` Exp
hist_el_size)
(Exp
hist_N Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`divUp` Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
num_groups'))
Exp -> Exp -> Exp
forall a. Fractional a => a -> a -> a
/ Exp -> Exp
forall v. PrimExp v -> PrimExp v
r64 Exp
hist_H
let hist_B :: Exp
hist_B = Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size'
Exp
hist_M0 <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_M0" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMax IntType
Int32) Exp
1 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMin IntType
Int32) (Exp -> Exp
forall v. PrimExp v -> PrimExp v
t64 Exp
hist_m') Exp
hist_B
let q_small :: Exp
q_small = Exp
2
Exp
hist_Nout <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_Nout" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ (SubExp -> Exp) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) [SubExp]
segment_dims
Exp
hist_Nin <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_Nin" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp) -> SubExp -> Exp
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. [a] -> a
last [SubExp]
space_sizes
Exp
work_asymp_M_max <-
if Bool
segmented then do
Exp
hist_T_hist_min <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_T_hist_min" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
IntType -> Exp -> Exp
forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
Int32 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMin IntType
Int64)
(IntType -> Exp -> Exp
forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
Int64 Exp
hist_Nin Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* IntType -> Exp -> Exp
forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
Int64 Exp
hist_Nout) (IntType -> Exp -> Exp
forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
Int64 Exp
hist_T)
Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`divUp`
IntType -> Exp -> Exp
forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
Int64 Exp
hist_Nout
let r :: Exp
r = Exp
hist_T_hist_min Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`divUp` Exp
hist_B
String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"work_asymp_M_max" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
hist_Nin Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` (Exp
r Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_H)
else String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"work_asymp_M_max" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
(Exp
hist_Nout Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_N) Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot`
((Exp
q_small Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
num_groups' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_H)
Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` [SegHistSlug] -> Exp
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs)
VName
hist_M <- String -> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"hist_M" (Exp -> ImpM KernelsMem HostEnv HostOp VName)
-> Exp -> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$
BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMin IntType
Int32) Exp
hist_M0 Exp
work_asymp_M_max
let hist_M_nonzero :: Exp
hist_M_nonzero = BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp (IntType -> BinOp
SMax IntType
Int32) Exp
1 (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ VName -> Exp
Imp.vi32 VName
hist_M
Exp
hist_C <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_C" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
Exp
hist_B Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`divUp` Exp
hist_M_nonzero
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local hist_M0" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_M0
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local work asymp M max" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
work_asymp_M_max
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local C" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_C
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local B" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_B
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local M" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ VName -> Exp
Imp.vi32 VName
hist_M
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"local memory needed" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ Exp
hist_H Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
hist_el_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* VName -> Exp
Imp.vi32 VName
hist_M
Exp
local_mem_needed <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"local_mem_needed" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ Exp
hist_el_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* VName -> Exp
Imp.vi32 VName
hist_M
Exp
hist_S <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_S" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ (Exp
hist_H Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
local_mem_needed) Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`divUp` VName -> Exp
Imp.vi32 VName
hist_L
let max_S :: Exp
max_S = case KernelBody KernelsMem -> Passage
bodyPassage KernelBody KernelsMem
kbody of
Passage
MustBeSinglePass -> Exp
1
Passage
MayBeMultiPass -> Int -> Exp
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int -> Exp) -> Int -> Exp
forall a b. (a -> b) -> a -> b
$ [Int] -> Int
forall a (f :: * -> *). (Num a, Ord a, Foldable f) => f a -> a
maxinum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (SegHistSlug -> Int) -> [SegHistSlug] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map SegHistSlug -> Int
slugMaxLocalMemPasses [SegHistSlug]
slugs
let pick_local :: Exp
pick_local =
Exp
hist_Nin Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.>=. Exp
hist_H
Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. (Exp
local_mem_needed Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<=. VName -> Exp
Imp.vi32 VName
hist_L)
Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. (Exp
hist_S Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<=. Exp
max_S)
Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. Exp
hist_C Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<=. Exp
hist_B
Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. VName -> Exp
Imp.vi32 VName
hist_M Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.>. Exp
0
groups_per_segment :: Count NumGroups Exp
groups_per_segment
| Bool
segmented = Count NumGroups Exp
num_groups' Count NumGroups Exp -> Count NumGroups Exp -> Count NumGroups Exp
forall e. IntegralExp e => e -> e -> e
`divUp` Exp -> Count NumGroups Exp
forall u e. e -> Count u e
Imp.Count Exp
hist_Nout
| Bool
otherwise = Count NumGroups Exp
num_groups'
run :: CallKernelGen ()
run = do
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"## Using local memory" Maybe Exp
forall a. Maybe a
Nothing
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Histogram size (H)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_H
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Multiplication degree (M)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ VName -> Exp
Imp.vi32 VName
hist_M
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Cooperation level (C)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_C
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of chunks (S)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_S
Bool -> CallKernelGen () -> CallKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
segmented (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Groups per segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
groups_per_segment
VName
-> Count NumGroups Exp
-> [PatElem KernelsMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> Exp
-> [SegHistSlug]
-> KernelBody KernelsMem
-> CallKernelGen ()
histKernelLocal VName
hist_M Count NumGroups Exp
groups_per_segment [PatElem KernelsMem]
map_pes
Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space Exp
hist_S [SegHistSlug]
slugs KernelBody KernelsMem
kbody
(Exp, CallKernelGen ()) -> CallKernelGen (Exp, CallKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp
pick_local, CallKernelGen ()
run)
compileSegHist :: Pattern KernelsMem
-> Count NumGroups SubExp -> Count GroupSize SubExp
-> SegSpace
-> [HistOp KernelsMem]
-> KernelBody KernelsMem
-> CallKernelGen ()
compileSegHist :: Pattern KernelsMem
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [HistOp KernelsMem]
-> KernelBody KernelsMem
-> CallKernelGen ()
compileSegHist (Pattern [PatElem KernelsMem]
_ [PatElem KernelsMem]
pes) Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [HistOp KernelsMem]
ops KernelBody KernelsMem
kbody = do
Count NumGroups Exp
num_groups' <- (SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Count NumGroups SubExp
-> ImpM KernelsMem HostEnv HostOp (Count NumGroups Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count NumGroups SubExp
num_groups
Count GroupSize Exp
group_size' <- (SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Count GroupSize SubExp
-> ImpM KernelsMem HostEnv HostOp (Count GroupSize Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp Count GroupSize SubExp
group_size
[Exp]
dims <- (SubExp -> ImpM KernelsMem HostEnv HostOp Exp)
-> [SubExp] -> ImpM KernelsMem HostEnv HostOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> ImpM KernelsMem HostEnv HostOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp ([SubExp] -> ImpM KernelsMem HostEnv HostOp [Exp])
-> [SubExp] -> ImpM KernelsMem HostEnv HostOp [Exp]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space
let num_red_res :: Int
num_red_res = [HistOp KernelsMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp KernelsMem]
ops Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((HistOp KernelsMem -> Int) -> [HistOp KernelsMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (HistOp KernelsMem -> [SubExp]) -> HistOp KernelsMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> [SubExp]
forall lore. HistOp lore -> [SubExp]
histNeutral) [HistOp KernelsMem]
ops)
([PatElemT LetDecMem]
all_red_pes, [PatElemT LetDecMem]
map_pes) = Int
-> [PatElemT LetDecMem]
-> ([PatElemT LetDecMem], [PatElemT LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
num_red_res [PatElem KernelsMem]
[PatElemT LetDecMem]
pes
segment_size :: Exp
segment_size = [Exp] -> Exp
forall a. [a] -> a
last [Exp]
dims
([Count Bytes Exp]
op_hs, [Count Bytes Exp]
op_seg_hs, [SegHistSlug]
slugs) <- [(Count Bytes Exp, Count Bytes Exp, SegHistSlug)]
-> ([Count Bytes Exp], [Count Bytes Exp], [SegHistSlug])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(Count Bytes Exp, Count Bytes Exp, SegHistSlug)]
-> ([Count Bytes Exp], [Count Bytes Exp], [SegHistSlug]))
-> ImpM
KernelsMem
HostEnv
HostOp
[(Count Bytes Exp, Count Bytes Exp, SegHistSlug)]
-> ImpM
KernelsMem
HostEnv
HostOp
([Count Bytes Exp], [Count Bytes Exp], [SegHistSlug])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (HistOp KernelsMem
-> CallKernelGen (Count Bytes Exp, Count Bytes Exp, SegHistSlug))
-> [HistOp KernelsMem]
-> ImpM
KernelsMem
HostEnv
HostOp
[(Count Bytes Exp, Count Bytes Exp, SegHistSlug)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SegSpace
-> HistOp KernelsMem
-> CallKernelGen (Count Bytes Exp, Count Bytes Exp, SegHistSlug)
computeHistoUsage SegSpace
space) [HistOp KernelsMem]
ops
Exp
h <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"h" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ Count Bytes Exp -> Exp
forall u e. Count u e -> e
Imp.unCount (Count Bytes Exp -> Exp) -> Count Bytes Exp -> Exp
forall a b. (a -> b) -> a -> b
$ [Count Bytes Exp] -> Count Bytes Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Count Bytes Exp]
op_hs
Exp
seg_h <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"seg_h" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ Count Bytes Exp -> Exp
forall u e. Count u e -> e
Imp.unCount (Count Bytes Exp -> Exp) -> Count Bytes Exp -> Exp
forall a b. (a -> b) -> a -> b
$ [Count Bytes Exp] -> Count Bytes Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Count Bytes Exp]
op_seg_hs
Exp -> CallKernelGen () -> CallKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sUnless (Exp
seg_h Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0) (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let hist_B :: Exp
hist_B = Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size'
Exp
hist_H <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_H" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ (HistOp KernelsMem -> Exp) -> [HistOp KernelsMem] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp)
-> (HistOp KernelsMem -> SubExp) -> HistOp KernelsMem -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth) [HistOp KernelsMem]
ops
let lockSize :: SegHistSlug -> Maybe a
lockSize SegHistSlug
slug = case SegHistSlug -> AtomicUpdate KernelsMem KernelEnv
slugAtomicUpdate SegHistSlug
slug of
AtomicLocking{} -> a -> Maybe a
forall a. a -> Maybe a
Just (a -> Maybe a) -> a -> Maybe a
forall a b. (a -> b) -> a -> b
$ PrimType -> a
forall a. Num a => PrimType -> a
primByteSize PrimType
int32
AtomicUpdate KernelsMem KernelEnv
_ -> Maybe a
forall a. Maybe a
Nothing
Exp
hist_el_size <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_el_size" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$ (Exp -> Exp -> Exp) -> Exp -> [Exp] -> Exp
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
(+) (Exp
h Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`divUp` Exp
hist_H) ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$
(SegHistSlug -> Maybe Exp) -> [SegHistSlug] -> [Exp]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe SegHistSlug -> Maybe Exp
forall a. Num a => SegHistSlug -> Maybe a
lockSize [SegHistSlug]
slugs
Exp
hist_N <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_N" Exp
segment_size
Exp
hist_RF <- String -> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"hist_RF" (Exp -> ImpM KernelsMem HostEnv HostOp Exp)
-> Exp -> ImpM KernelsMem HostEnv HostOp Exp
forall a b. (a -> b) -> a -> b
$
[Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((SegHistSlug -> Exp) -> [SegHistSlug] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32(SubExp -> Exp) -> (SegHistSlug -> SubExp) -> SegHistSlug -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histRaceFactor (HistOp KernelsMem -> SubExp)
-> (SegHistSlug -> HistOp KernelsMem) -> SegHistSlug -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegHistSlug -> HistOp KernelsMem
slugOp) [SegHistSlug]
slugs)
Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot`
[SegHistSlug] -> Exp
forall i a. Num i => [a] -> i
genericLength [SegHistSlug]
slugs
let hist_T :: Exp
hist_T = Count NumGroups Exp -> Exp
forall u e. Count u e -> e
unCount Count NumGroups Exp
num_groups' Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Count GroupSize Exp -> Exp
forall u e. Count u e -> e
unCount Count GroupSize Exp
group_size'
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"\n# SegHist" Maybe Exp
forall a. Maybe a
Nothing
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of threads (T)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_T
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Desired group size (B)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_B
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Histogram size (H)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_H
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Input elements per histogram (N)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_N
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Number of segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$
Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ ((VName, SubExp) -> Exp) -> [(VName, SubExp)] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 (SubExp -> Exp)
-> ((VName, SubExp) -> SubExp) -> (VName, SubExp) -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) [(VName, SubExp)]
segment_dims
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Histogram element size (el_size)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_el_size
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Race factor (RF)" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
hist_RF
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Memory per set of subhistograms per segment" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
h
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code HostOp
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"Memory per set of subhistograms times segments" (Maybe Exp -> Code HostOp) -> Maybe Exp -> Code HostOp
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just Exp
seg_h
(Exp
use_local_memory, CallKernelGen ()
run_in_local_memory) <-
[PatElem KernelsMem]
-> Exp
-> SegSpace
-> Exp
-> Exp
-> Exp
-> Exp
-> [SegHistSlug]
-> KernelBody KernelsMem
-> CallKernelGen (Exp, CallKernelGen ())
localMemoryCase [PatElem KernelsMem]
[PatElemT LetDecMem]
map_pes Exp
hist_T SegSpace
space Exp
hist_H Exp
hist_el_size Exp
hist_N Exp
hist_RF [SegHistSlug]
slugs KernelBody KernelsMem
kbody
Exp -> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall lore r op.
Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf Exp
use_local_memory CallKernelGen ()
run_in_local_memory (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$
[PatElem KernelsMem]
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> SegSpace
-> [SegHistSlug]
-> KernelBody KernelsMem
-> CallKernelGen ()
histKernelGlobal [PatElem KernelsMem]
[PatElemT LetDecMem]
map_pes Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegSpace
space [SegHistSlug]
slugs KernelBody KernelsMem
kbody
let pes_per_op :: [[PatElemT LetDecMem]]
pes_per_op = [Int] -> [PatElemT LetDecMem] -> [[PatElemT LetDecMem]]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp KernelsMem -> Int) -> [HistOp KernelsMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int)
-> (HistOp KernelsMem -> [VName]) -> HistOp KernelsMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp KernelsMem -> [VName]
forall lore. HistOp lore -> [VName]
histDest) [HistOp KernelsMem]
ops) [PatElemT LetDecMem]
all_red_pes
[(SegHistSlug, [PatElemT LetDecMem], HistOp KernelsMem)]
-> ((SegHistSlug, [PatElemT LetDecMem], HistOp KernelsMem)
-> CallKernelGen ())
-> CallKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegHistSlug]
-> [[PatElemT LetDecMem]]
-> [HistOp KernelsMem]
-> [(SegHistSlug, [PatElemT LetDecMem], HistOp KernelsMem)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SegHistSlug]
slugs [[PatElemT LetDecMem]]
pes_per_op [HistOp KernelsMem]
ops) (((SegHistSlug, [PatElemT LetDecMem], HistOp KernelsMem)
-> CallKernelGen ())
-> CallKernelGen ())
-> ((SegHistSlug, [PatElemT LetDecMem], HistOp KernelsMem)
-> CallKernelGen ())
-> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegHistSlug
slug, [PatElemT LetDecMem]
red_pes, HistOp KernelsMem
op) -> do
let num_histos :: VName
num_histos = SegHistSlug -> VName
slugNumSubhistos SegHistSlug
slug
subhistos :: [VName]
subhistos = (SubhistosInfo -> VName) -> [SubhistosInfo] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map SubhistosInfo -> VName
subhistosArray ([SubhistosInfo] -> [VName]) -> [SubhistosInfo] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegHistSlug -> [SubhistosInfo]
slugSubhistos SegHistSlug
slug
let unitHistoCase :: CallKernelGen ()
unitHistoCase =
[(PatElemT LetDecMem, VName)]
-> ((PatElemT LetDecMem, VName) -> CallKernelGen ())
-> CallKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem] -> [VName] -> [(PatElemT LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LetDecMem]
red_pes [VName]
subhistos) (((PatElemT LetDecMem, VName) -> CallKernelGen ())
-> CallKernelGen ())
-> ((PatElemT LetDecMem, VName) -> CallKernelGen ())
-> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
pe, VName
subhisto) -> do
VName
pe_mem <- MemLocation -> VName
memLocationName (MemLocation -> VName)
-> (ArrayEntry -> MemLocation) -> ArrayEntry -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ArrayEntry -> MemLocation
entryArrayLocation (ArrayEntry -> VName)
-> ImpM KernelsMem HostEnv HostOp ArrayEntry
-> ImpM KernelsMem HostEnv HostOp VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$>
VName -> ImpM KernelsMem HostEnv HostOp ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
pe)
VName
subhisto_mem <- MemLocation -> VName
memLocationName (MemLocation -> VName)
-> (ArrayEntry -> MemLocation) -> ArrayEntry -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ArrayEntry -> MemLocation
entryArrayLocation (ArrayEntry -> VName)
-> ImpM KernelsMem HostEnv HostOp ArrayEntry
-> ImpM KernelsMem HostEnv HostOp VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$>
VName -> ImpM KernelsMem HostEnv HostOp ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray VName
subhisto
Code HostOp -> CallKernelGen ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code HostOp -> CallKernelGen ())
-> Code HostOp -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> Space -> Code HostOp
forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
pe_mem VName
subhisto_mem (Space -> Code HostOp) -> Space -> Code HostOp
forall a b. (a -> b) -> a -> b
$ String -> Space
Space String
"device"
Exp -> CallKernelGen () -> CallKernelGen () -> CallKernelGen ()
forall lore r op.
Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (VName -> PrimType -> Exp
Imp.var VName
num_histos PrimType
int32 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
1) CallKernelGen ()
unitHistoCase (CallKernelGen () -> CallKernelGen ())
-> CallKernelGen () -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let num_buckets :: SubExp
num_buckets = HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth HistOp KernelsMem
op
VName
bucket_id <- String -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"bucket_id"
VName
subhistogram_id <- String -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"subhistogram_id"
[VName]
vector_ids <- (SubExp -> ImpM KernelsMem HostEnv HostOp VName)
-> [SubExp] -> ImpM KernelsMem HostEnv HostOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (ImpM KernelsMem HostEnv HostOp VName
-> SubExp -> ImpM KernelsMem HostEnv HostOp VName
forall a b. a -> b -> a
const (ImpM KernelsMem HostEnv HostOp VName
-> SubExp -> ImpM KernelsMem HostEnv HostOp VName)
-> ImpM KernelsMem HostEnv HostOp VName
-> SubExp
-> ImpM KernelsMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ String -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"vector_id") ([SubExp] -> ImpM KernelsMem HostEnv HostOp [VName])
-> [SubExp] -> ImpM KernelsMem HostEnv HostOp [VName]
forall a b. (a -> b) -> a -> b
$
Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp KernelsMem
op
VName
flat_gtid <- String -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"flat_gtid"
let lvl :: SegLevel
lvl = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegVirt
SegVirt
segred_space :: SegSpace
segred_space =
VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
flat_gtid ([(VName, SubExp)] -> SegSpace) -> [(VName, SubExp)] -> SegSpace
forall a b. (a -> b) -> a -> b
$
[(VName, SubExp)]
segment_dims [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++
[(VName
bucket_id, SubExp
num_buckets)] [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++
[VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
vector_ids (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (Shape -> [SubExp]) -> Shape -> [SubExp]
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp KernelsMem
op) [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++
[(VName
subhistogram_id, VName -> SubExp
Var VName
num_histos)]
let segred_op :: SegBinOp KernelsMem
segred_op = Commutativity
-> LambdaT KernelsMem -> [SubExp] -> Shape -> SegBinOp KernelsMem
forall lore.
Commutativity -> Lambda lore -> [SubExp] -> Shape -> SegBinOp lore
SegBinOp Commutativity
Commutative (HistOp KernelsMem -> LambdaT KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp HistOp KernelsMem
op) (HistOp KernelsMem -> [SubExp]
forall lore. HistOp lore -> [SubExp]
histNeutral HistOp KernelsMem
op) Shape
forall a. Monoid a => a
mempty
Pattern KernelsMem
-> SegLevel
-> SegSpace
-> [SegBinOp KernelsMem]
-> DoSegBody
-> CallKernelGen ()
compileSegRed' ([PatElemT LetDecMem] -> [PatElemT LetDecMem] -> PatternT LetDecMem
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT LetDecMem]
red_pes) SegLevel
lvl SegSpace
segred_space [SegBinOp KernelsMem
segred_op] (DoSegBody -> CallKernelGen ()) -> DoSegBody -> CallKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[(SubExp, [Exp])] -> InKernelGen ()
red_cont ->
[(SubExp, [Exp])] -> InKernelGen ()
red_cont ([(SubExp, [Exp])] -> InKernelGen ())
-> [(SubExp, [Exp])] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ ((VName -> (SubExp, [Exp])) -> [VName] -> [(SubExp, [Exp])])
-> [VName] -> (VName -> (SubExp, [Exp])) -> [(SubExp, [Exp])]
forall a b c. (a -> b -> c) -> b -> a -> c
flip (VName -> (SubExp, [Exp])) -> [VName] -> [(SubExp, [Exp])]
forall a b. (a -> b) -> [a] -> [b]
map [VName]
subhistos ((VName -> (SubExp, [Exp])) -> [(SubExp, [Exp])])
-> (VName -> (SubExp, [Exp])) -> [(SubExp, [Exp])]
forall a b. (a -> b) -> a -> b
$ \VName
subhisto ->
(VName -> SubExp
Var VName
subhisto, (VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Exp
Imp.vi32 ([VName] -> [Exp]) -> [VName] -> [Exp]
forall a b. (a -> b) -> a -> b
$
((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst [(VName, SubExp)]
segment_dims [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName
subhistogram_id, VName
bucket_id] [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
vector_ids)
-- 'segment_dims' is the list of (index variable, extent) pairs returned by
-- 'unSegSpace space' with its final pair removed via 'init'.
-- NOTE(review): 'init' is partial and fails on an empty list; this presumably
-- relies on 'space' always having at least one dimension; confirm at callers.
where segment_dims :: [(VName, SubExp)]
segment_dims = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
init ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space