{-# LANGUAGE TypeFamilies #-}

-- | Generation of kernels with group-level bodies.
module Futhark.CodeGen.ImpGen.GPU.Group
  ( sKernelGroup,
    compileGroupResult,
    groupOperations,

    -- * Precomputation
    Precomputed,
    precomputeConstants,
    precomputedConstants,
    atomicUpdateLocking,
  )
where

import Control.Monad.Except
import Data.Bifunctor
import Data.List (partition, zip4)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.Construct (fullSliceNum)
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.Util (chunks, mapAccumLM, takeLast)
import Futhark.Util.IntegralExp (divUp, rem)
import Prelude hiding (quot, rem)

-- | @flattenArray k flat arr@ flattens the outer @k@ dimensions of
-- @arr@ to @flat@.  (Make sure @flat@ is the product of those
-- dimensions or you'll have a bad time.)
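--
-- A usage sketch (hypothetical names): flattening both dimensions of
-- an @[n][m]@ array @xss@ into a fresh one-dimensional view:
--
-- > flat <- dPrimV "flat" $ pe64 n * pe64 m
-- > xss_flat <- flattenArray 2 flat xss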
flattenArray :: Int -> TV Int64 -> VName -> ImpM rep r op VName
flattenArray k flat arr = do
  ArrayEntry arr_loc pt <- lookupArray arr
  let flat_shape = Shape $ Var (tvVar flat) : drop k (memLocShape arr_loc)
  sArray (baseString arr ++ "_flat") pt flat_shape (memLocName arr_loc) $
    IxFun.reshape (memLocIxFun arr_loc) $
      map pe64 $
        shapeDims flat_shape

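-- | @sliceArray start size arr@ produces a view of @size@ elements of
-- the outer dimension of @arr@, beginning at offset @start@.  For
-- example (hypothetical names, mirroring the chunked loops below):
--
-- > arr_chunk <- sliceArray (sExt64 chunk_start) chunk_size arr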
sliceArray :: Imp.TExp Int64 -> TV Int64 -> VName -> ImpM rep r op VName
sliceArray start size arr = do
  MemLoc mem _ ixfun <- entryArrayLoc <$> lookupArray arr
  arr_t <- lookupType arr
  let slice =
        fullSliceNum
          (map Imp.pe64 (arrayDims arr_t))
          [DimSlice start (tvExp size) 1]
  sArray
    (baseString arr ++ "_chunk")
    (elemType arr_t)
    (arrayShape arr_t `setOuterDim` Var (tvVar size))
    mem
    $ IxFun.slice ixfun slice

-- | @applyLambda lam dests args@ emits code that:
--
-- 1. Binds each parameter of @lam@ to the corresponding element of
--    @args@, interpreted as a (name,slice) pair (as in 'copyDWIM').
--    Use an empty list for a scalar.
--
-- 2. Executes the body of @lam@.
--
-- 3. Binds the t'SubExp's that are the 'Result' of @lam@ to the
-- provided @dest@s, again interpreted as the destination for a
-- 'copyDWIM'.
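--
-- For example (hypothetical names), assuming @lam@ takes two
-- parameters and returns one result, this combines element @i@ of
-- arrays @xs@ and @ys@ into the scalar destination @acc@:
--
-- > applyLambda lam [(acc, [])] [(Var xs, [DimFix i]), (Var ys, [DimFix i])]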
applyLambda ::
  Mem rep inner =>
  Lambda rep ->
  [(VName, [DimIndex (Imp.TExp Int64)])] ->
  [(SubExp, [DimIndex (Imp.TExp Int64)])] ->
  ImpM rep r op ()
applyLambda lam dests args = do
  dLParams $ lambdaParams lam
  forM_ (zip (lambdaParams lam) args) $ \(p, (arg, arg_slice)) ->
    copyDWIM (paramName p) [] arg arg_slice
  compileStms mempty (bodyStms $ lambdaBody lam) $ do
    let res = map resSubExp $ bodyResult $ lambdaBody lam
    forM_ (zip dests res) $ \((dest, dest_slice), se) ->
      copyDWIM dest dest_slice se []

-- | As applyLambda, but first rename the names in the lambda.  This
-- makes it safe to apply it in multiple places.  (It might be safe
-- anyway, but you have to be more careful - use this if you are in
-- doubt.)
applyRenamedLambda ::
  Mem rep inner =>
  Lambda rep ->
  [(VName, [DimIndex (Imp.TExp Int64)])] ->
  [(SubExp, [DimIndex (Imp.TExp Int64)])] ->
  ImpM rep r op ()
applyRenamedLambda lam dests args = do
  lam_renamed <- renameLambda lam
  applyLambda lam_renamed dests args

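-- | Run an action over @w@ elements in chunks of at most the group
-- size; the action receives the offset of the current chunk and its
-- actual size (the last chunk may be short).  A sketch of intended
-- use (hypothetical names):
--
-- > groupChunkLoop w $ \chunk_start chunk_size -> do
-- >   arr_chunk <- sliceArray (sExt64 chunk_start) chunk_size arr
-- >   groupReduce (sExt32 (tvExp chunk_size)) lam [arr_chunk]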
groupChunkLoop ::
  Imp.TExp Int32 ->
  (Imp.TExp Int32 -> TV Int64 -> InKernelGen ()) ->
  InKernelGen ()
groupChunkLoop w m = do
  constants <- kernelConstants <$> askEnv
  let max_chunk_size = sExt32 $ kernelGroupSize constants
  num_chunks <- dPrimVE "num_chunks" $ w `divUp` max_chunk_size
  sFor "chunk_i" num_chunks $ \chunk_i -> do
    chunk_start <-
      dPrimVE "chunk_start" $ chunk_i * max_chunk_size
    chunk_end <-
      dPrimVE "chunk_end" $ sMin32 w (chunk_start + max_chunk_size)
    chunk_size <-
      dPrimV "chunk_size" $ sExt64 $ chunk_end - chunk_start
    m chunk_start chunk_size

virtualisedGroupScan ::
  Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
  Imp.TExp Int32 ->
  Lambda GPUMem ->
  [VName] ->
  InKernelGen ()
virtualisedGroupScan seg_flag w lam arrs = do
  groupChunkLoop w $ \chunk_start chunk_size -> do
    constants <- kernelConstants <$> askEnv
    let ltid = kernelLocalThreadId constants
        crosses_segment =
          case seg_flag of
            Nothing -> false
            Just flag_true ->
              flag_true (sExt32 (chunk_start - 1)) (sExt32 chunk_start)
    sComment "possibly incorporate carry" $
      sWhen (chunk_start .>. 0 .&&. ltid .==. 0 .&&. bNot crosses_segment) $ do
        carry_idx <- dPrimVE "carry_idx" $ sExt64 chunk_start - 1
        applyRenamedLambda
          lam
          (zip arrs $ repeat [DimFix $ sExt64 chunk_start])
          ( zip (map Var arrs) (repeat [DimFix carry_idx])
              ++ zip (map Var arrs) (repeat [DimFix $ sExt64 chunk_start])
          )

    arrs_chunks <- mapM (sliceArray (sExt64 chunk_start) chunk_size) arrs

    sOp $ Imp.ErrorSync Imp.FenceLocal

    groupScan seg_flag (sExt64 w) (tvExp chunk_size) lam arrs_chunks

copyInGroup :: CopyCompiler GPUMem KernelEnv Imp.KernelOp
copyInGroup pt destloc srcloc = do
  dest_space <- entryMemSpace <$> lookupMemory (memLocName destloc)
  src_space <- entryMemSpace <$> lookupMemory (memLocName srcloc)

  let src_ixfun = memLocIxFun srcloc
      dims = IxFun.shape src_ixfun
      rank = length dims

  case (dest_space, src_space) of
    (ScalarSpace destds _, ScalarSpace srcds _) -> do
      let fullDim d = DimSlice 0 d 1
          destslice' =
            Slice $
              replicate (rank - length destds) (DimFix 0)
                ++ takeLast (length destds) (map fullDim dims)
          srcslice' =
            Slice $
              replicate (rank - length srcds) (DimFix 0)
                ++ takeLast (length srcds) (map fullDim dims)
      copyElementWise
        pt
        (sliceMemLoc destloc destslice')
        (sliceMemLoc srcloc srcslice')
    _ -> do
      groupCoverSpace (map sExt32 dims) $ \is ->
        copyElementWise
          pt
          (sliceMemLoc destloc (Slice $ map (DimFix . sExt64) is))
          (sliceMemLoc srcloc (Slice $ map (DimFix . sExt64) is))
      sOp $ Imp.Barrier Imp.FenceLocal

localThreadIDs :: [SubExp] -> InKernelGen [Imp.TExp Int64]
localThreadIDs dims = do
  ltid <- sExt64 . kernelLocalThreadId . kernelConstants <$> askEnv
  let dims' = map pe64 dims
  maybe (dIndexSpace' "ltid" dims' ltid) (pure . map sExt64)
    . M.lookup dims
    . kernelLocalIdMap
    . kernelConstants
    =<< askEnv

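-- | Split a 'SegSpace' into the dimensions mentioned in the
-- 'SegSeqDims' (to be executed sequentially) and the remaining,
-- parallel ones.  A worked example, for a space with dimensions
-- @[(i, n), (j, m)]@:
--
-- > partitionSeqDims (SegSeqDims [0]) space == ([(i, n)], [(j, m)])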
partitionSeqDims :: SegSeqDims -> SegSpace -> ([(VName, SubExp)], [(VName, SubExp)])
partitionSeqDims (SegSeqDims seq_is) space =
  bimap (map fst) (map fst) $
    partition ((`elem` seq_is) . snd) (zip (unSegSpace space) [0 ..])

sanityCheckLevel :: SegLevel -> InKernelGen ()
sanityCheckLevel SegThread {} = pure ()
sanityCheckLevel SegGroup {} =
  error "compileGroupOp: unexpected group-level SegOp."

compileFlatId :: SegLevel -> SegSpace -> InKernelGen ()
compileFlatId lvl space = do
  sanityCheckLevel lvl
  ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  dPrimV_ (segFlat space) ltid

-- Construct the necessary lock arrays for an intra-group histogram.
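-- The lock array is only as long as the group, so a bucket at
-- multi-dimensional index @is@ shares lock number
--
-- > flattenIndex dims is `rem` num_locks
--
-- with other buckets; this may cost contention, but is always correct.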
prepareIntraGroupSegHist ::
  Count GroupSize SubExp ->
  [HistOp GPUMem] ->
  InKernelGen [[Imp.TExp Int64] -> InKernelGen ()]
prepareIntraGroupSegHist group_size =
  fmap snd . mapAccumLM onOp Nothing
  where
    onOp l op = do
      constants <- kernelConstants <$> askEnv
      atomicBinOp <- kernelAtomics <$> askEnv

      let local_subhistos = histDest op

      case (l, atomicUpdateLocking atomicBinOp $ histOp op) of
        (_, AtomicPrim f) -> pure (l, f (Space "local") local_subhistos)
        (_, AtomicCAS f) -> pure (l, f (Space "local") local_subhistos)
        (Just l', AtomicLocking f) -> pure (l, f l' (Space "local") local_subhistos)
        (Nothing, AtomicLocking f) -> do
          locks <- newVName "locks"

          let num_locks = pe64 $ unCount group_size
              dims = map pe64 $ shapeDims (histOpShape op <> histShape op)
              l' = Locking locks 0 1 0 (pure . (`rem` num_locks) . flattenIndex dims)
              locks_t = Array int32 (Shape [unCount group_size]) NoUniqueness

          locks_mem <- sAlloc "locks_mem" (typeSize locks_t) $ Space "local"
          dArray locks int32 (arrayShape locks_t) locks_mem $
            IxFun.iota . map pe64 . arrayDims $
              locks_t

          sComment "All locks start out unlocked" $
            groupCoverSpace [kernelGroupSize constants] $ \is ->
              copyDWIMFix locks is (intConst Int32 0) []

          pure (Just l', f l' (Space "local") local_subhistos)

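-- | Cover a 'SegSpace' with the threads of the group, according to
-- the requested virtualisation.  In the virtualised case with a
-- precomputed chunk count, thread @ltid@ executes the iterations
--
-- > i = chunk_i * group_size + ltid
--
-- for each chunk, guarded by a bounds check on the space, so a group
-- smaller than the space still covers every element exactly once.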
groupCoverSegSpace :: SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace virt space m = do
  let (ltids, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims

  constants <- kernelConstants <$> askEnv
  let group_size = kernelGroupSize constants
  -- Maybe we can statically detect that this is actually a
  -- SegNoVirtFull and generate ever-so-slightly simpler code.
  let virt' = if dims' == [group_size] then SegNoVirtFull (SegSeqDims []) else virt
  case virt' of
    SegVirt -> do
      iters <- M.lookup dims . kernelChunkItersMap . kernelConstants <$> askEnv
      case iters of
        Nothing -> do
          iterations <- dPrimVE "iterations" $ product $ map sExt32 dims'
          groupLoop iterations $ \i -> do
            dIndexSpace (zip ltids dims') $ sExt64 i
            m
        Just num_chunks -> do
          let ltid = kernelLocalThreadId constants
          sFor "chunk_i" num_chunks $ \chunk_i -> do
            i <- dPrimVE "i" $ chunk_i * sExt32 group_size + ltid
            dIndexSpace (zip ltids dims') $ sExt64 i
            sWhen (inBounds (Slice (map (DimFix . le64) ltids)) dims') m
    SegNoVirt -> localOps threadOperations $ do
      zipWithM_ dPrimV_ ltids =<< localThreadIDs dims
      sWhen (isActive $ zip ltids dims) m
    SegNoVirtFull seq_dims -> do
      let ((ltids_seq, dims_seq), (ltids_par, dims_par)) =
            bimap unzip unzip $ partitionSeqDims seq_dims space
      sLoopNest (Shape dims_seq) $ \is_seq -> do
        zipWithM_ dPrimV_ ltids_seq is_seq
        localOps threadOperations $ do
          zipWithM_ dPrimV_ ltids_par =<< localThreadIDs dims_par
          m

compileGroupExp :: ExpCompiler GPUMem KernelEnv Imp.KernelOp
compileGroupExp (Pat [pe]) (BasicOp (Opaque _ se)) =
  -- Cannot print in GPU code.
  copyDWIM (patElemName pe) [] se []
-- The static arrays stuff does not work inside kernels.
compileGroupExp (Pat [dest]) (BasicOp (ArrayLit es _)) =
  forM_ (zip [0 ..] es) $ \(i, e) ->
    copyDWIMFix (patElemName dest) [fromIntegral (i :: Int64)] e []
compileGroupExp _ (BasicOp (UpdateAcc acc is vs)) =
  updateAcc acc is vs
compileGroupExp (Pat [dest]) (BasicOp (Replicate ds se)) = do
  flat <- newVName "rep_flat"
  is <- replicateM (shapeRank ds) (newVName "rep_i")
  let is' = map le64 is
  groupCoverSegSpace SegVirt (SegSpace flat $ zip is $ shapeDims ds) $
    copyDWIMFix (patElemName dest) is' se []
  sOp $ Imp.Barrier Imp.FenceLocal
compileGroupExp (Pat [dest]) (BasicOp (Rotate rs arr)) = do
  ds <- map pe64 . arrayDims <$> lookupType arr
  groupCoverSpace ds $ \is -> do
    is' <- sequence $ zipWith3 rotate ds rs is
    copyDWIMFix (patElemName dest) is (Var arr) is'
  sOp $ Imp.Barrier Imp.FenceLocal
  where
    rotate d r i = dPrimVE "rot_i" $ rotateIndex d (pe64 r) i
compileGroupExp (Pat [dest]) (BasicOp (Iota n e s it)) = do
  n' <- toExp n
  e' <- toExp e
  s' <- toExp s
  groupLoop (TPrimExp n') $ \i' -> do
    x <-
      dPrimV "x" $
        TPrimExp $
          BinOpExp (Add it OverflowUndef) e' $
            BinOpExp (Mul it OverflowUndef) (untyped i') s'
    copyDWIMFix (patElemName dest) [i'] (Var (tvVar x)) []
  sOp $ Imp.Barrier Imp.FenceLocal

-- When generating code for a scalar in-place update, we must make
-- sure that only one thread performs the write.  When writing an
-- array, the group-level copy code will take care of doing the right
-- thing.
compileGroupExp (Pat [pe]) (BasicOp (Update safety _ slice se))
  | null $ sliceDims slice = do
      sOp $ Imp.Barrier Imp.FenceLocal
      ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
      sWhen (ltid .==. 0) $
        case safety of
          Unsafe -> write
          Safe -> sWhen (inBounds slice' dims) write
      sOp $ Imp.Barrier Imp.FenceLocal
  where
    slice' = fmap pe64 slice
    dims = map pe64 $ arrayDims $ patElemType pe
    write = copyDWIM (patElemName pe) (unSlice slice') se []
compileGroupExp dest e =
  defCompileExp dest e

compileGroupOp :: OpCompiler GPUMem KernelEnv Imp.KernelOp
compileGroupOp :: OpCompiler GPUMem KernelEnv KernelOp
compileGroupOp Pat (LetDec GPUMem)
pat (Alloc SubExp
size Space
space) =
  Pat LetDecMem -> SubExp -> Space -> InKernelGen ()
kernelAlloc Pat (LetDec GPUMem)
pat SubExp
size Space
space
compileGroupOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegMap SegLevel
lvl SegSpace
space [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody GPUMem
body))) = do
  SegLevel -> SegSpace -> InKernelGen ()
compileFlatId SegLevel
lvl SegSpace
space

  SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace (SegLevel -> SegVirt
segVirt SegLevel
lvl) SegSpace
space forall a b. (a -> b) -> a -> b
$
    forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
body) forall a b. (a -> b) -> a -> b
$
      forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (SegSpace -> PatElem LetDecMem -> KernelResult -> InKernelGen ()
compileThreadResult SegSpace
space) (forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec GPUMem)
pat) forall a b. (a -> b) -> a -> b
$
        forall {k} (rep :: k). KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
body
  forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
compileGroupOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegScan SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
scans [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody GPUMem
body))) = do
  SegLevel -> SegSpace -> InKernelGen ()
compileFlatId SegLevel
lvl SegSpace
space

  let ([VName]
ltids, [SubExp]
dims) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      dims' :: [TPrimExp Int64 VName]
dims' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims

  SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace (SegLevel -> SegVirt
segVirt SegLevel
lvl) SegSpace
space forall a b. (a -> b) -> a -> b
$
    forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
body) forall a b. (a -> b) -> a -> b
$
      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip (forall dec. Pat dec -> [VName]
patNames Pat (LetDec GPUMem)
pat) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
body) forall a b. (a -> b) -> a -> b
$ \(VName
dest, KernelResult
res) ->
        forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
          VName
dest
          (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
ltids)
          (KernelResult -> SubExp
kernelResultSubExp KernelResult
res)
          []

  Fence
fence <- [VName] -> InKernelGen Fence
fenceForArrays forall a b. (a -> b) -> a -> b
$ forall dec. Pat dec -> [VName]
patNames Pat (LetDec GPUMem)
pat
  forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
fence

  let segment_size :: TPrimExp Int64 VName
segment_size = forall a. [a] -> a
last [TPrimExp Int64 VName]
dims'
      crossesSegment :: TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment TExp Int32
from TExp Int32
to =
        (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to forall a. Num a => a -> a -> a
- forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
from) forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName
segment_size)

  -- groupScan needs to treat the scan output as a one-dimensional
  -- array of scan elements, so we invent some new flattened arrays
  -- here.
  TV Int64
dims_flat <- forall {k1} {k2} (t :: k1) (rep :: k2) r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"dims_flat" forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims'
  let scan :: SegBinOp GPUMem
scan = forall a. [a] -> a
head [SegBinOp GPUMem]
scans
      num_scan_results :: Int
num_scan_results = forall (t :: * -> *) a. Foldable t => t a -> Int
length forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan
  [VName]
arrs_flat <-
    forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall {k} (rep :: k) r op.
Int -> TV Int64 -> VName -> ImpM rep r op VName
flattenArray (forall (t :: * -> *) a. Foldable t => t a -> Int
length [TPrimExp Int64 VName]
dims') TV Int64
dims_flat) forall a b. (a -> b) -> a -> b
$
      forall a. Int -> [a] -> [a]
take Int
num_scan_results forall a b. (a -> b) -> a -> b
$
        forall dec. Pat dec -> [VName]
patNames Pat (LetDec GPUMem)
pat

  case SegLevel -> SegVirt
segVirt SegLevel
lvl of
    SegVirt
SegVirt ->
      Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Int32 -> Lambda GPUMem -> [VName] -> InKernelGen ()
virtualisedGroupScan
        (forall a. a -> Maybe a
Just TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment)
        (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dims_flat)
        (forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan)
        [VName]
arrs_flat
    SegVirt
_ ->
      Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> Lambda GPUMem
-> [VName]
-> InKernelGen ()
groupScan
        (forall a. a -> Maybe a
Just TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment)
        (forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims')
        (forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims')
        (forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan)
        [VName]
arrs_flat
compileGroupOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegRed SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
ops [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody GPUMem
body))) = do
  SegLevel -> SegSpace -> InKernelGen ()
compileFlatId SegLevel
lvl SegSpace
space

  let dims' :: [TPrimExp Int64 VName]
dims' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
      mkTempArr :: TypeBase (ShapeBase SubExp) NoUniqueness
-> ImpM GPUMem KernelEnv KernelOp VName
mkTempArr TypeBase (ShapeBase SubExp) NoUniqueness
t =
        forall {k} (rep :: k) r op.
[Char]
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray [Char]
"red_arr" (forall shape u. TypeBase shape u -> PrimType
elemType TypeBase (ShapeBase SubExp) NoUniqueness
t) (forall d. [d] -> ShapeBase d
Shape [SubExp]
dims forall a. Semigroup a => a -> a -> a
<> forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape TypeBase (ShapeBase SubExp) NoUniqueness
t) forall a b. (a -> b) -> a -> b
$ [Char] -> Space
Space [Char]
"local"

  [VName]
tmp_arrs <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM TypeBase (ShapeBase SubExp) NoUniqueness
-> ImpM GPUMem KernelEnv KernelOp VName
mkTempArr forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (forall {k} (rep :: k).
Lambda rep -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda) [SegBinOp GPUMem]
ops
  SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace (SegLevel -> SegVirt
segVirt SegLevel
lvl) SegSpace
space forall a b. (a -> b) -> a -> b
$
    forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
body) forall a b. (a -> b) -> a -> b
$ do
      let ([KernelResult]
red_res, [KernelResult]
map_res) =
            forall a. Int -> [a] -> ([a], [a])
splitAt (forall {k} (rep :: k). [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem]
ops) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
body
      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
tmp_arrs [KernelResult]
red_res) forall a b. (a -> b) -> a -> b
$ \(VName
dest, KernelResult
res) ->
        forall {k} (rep :: k) r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
dest (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
ltids) (KernelResult -> SubExp
kernelResultSubExp KernelResult
res) []
      forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (SegSpace -> PatElem LetDecMem -> KernelResult -> InKernelGen ()
compileThreadResult SegSpace
space) [PatElem LetDecMem]
map_pes [KernelResult]
map_res

  forall {k} op (rep :: k) r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal

  let tmps_for_ops :: [[VName]]
tmps_for_ops = forall a. [Int] -> [a] -> [[a]]
chunks (forall a b. (a -> b) -> [a] -> [b]
map (forall (t :: * -> *) a. Foldable t => t a -> Int
length forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). SegBinOp rep -> [SubExp]
segBinOpNeutral) [SegBinOp GPUMem]
ops) [VName]
tmp_arrs
  case SegLevel -> SegVirt
segVirt SegLevel
lvl of
    SegVirt
SegVirt -> [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
virtCase [TPrimExp Int64 VName]
dims' [[VName]]
tmps_for_ops
    SegVirt
_ -> [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
nonvirtCase [TPrimExp Int64 VName]
dims' [[VName]]
tmps_for_ops
  where
    ([VName]
ltids, [SubExp]
dims) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
    ([PatElem LetDecMem]
red_pes, [PatElem LetDecMem]
map_pes) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall {k} (rep :: k). [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem]
ops) forall a b. (a -> b) -> a -> b
$ forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec GPUMem)
pat

    virtCase [dim'] tmps_for_ops = do
      ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
      groupChunkLoop (sExt32 dim') $ \chunk_start chunk_size -> do
        sComment "possibly incorporate carry" $
          sWhen (chunk_start .>. 0 .&&. ltid .==. 0) $
            forM_ (zip ops tmps_for_ops) $ \(op, tmps) ->
              applyRenamedLambda
                (segBinOpLambda op)
                (zip tmps $ repeat [DimFix $ sExt64 chunk_start])
                ( zip (map (Var . patElemName) red_pes) (repeat [])
                    ++ zip (map Var tmps) (repeat [DimFix $ sExt64 chunk_start])
                )

        sOp $ Imp.ErrorSync Imp.FenceLocal

        forM_ (zip ops tmps_for_ops) $ \(op, tmps) -> do
          tmps_chunks <- mapM (sliceArray (sExt64 chunk_start) chunk_size) tmps
          groupReduce (sExt32 (tvExp chunk_size)) (segBinOpLambda op) tmps_chunks

        sOp $ Imp.ErrorSync Imp.FenceLocal

        forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
          copyDWIMFix (patElemName pe) [] (Var arr) [sExt64 chunk_start]
    virtCase dims' tmps_for_ops = do
      dims_flat <- dPrimV "dims_flat" $ product dims'
      let segment_size = last dims'
          crossesSegment from to =
            (sExt64 to - sExt64 from) .>. (sExt64 to `rem` sExt64 segment_size)
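      -- Illustrative arithmetic: with segment_size = 4, going from
      -- flat index 2 to 5 crosses a boundary, as 5 - 2 = 3 exceeds
      -- 5 `rem` 4 = 1, while going from 4 to 6 does not, as
      -- 6 - 4 = 2 does not exceed 6 `rem` 4 = 2.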

      forM_ (zip ops tmps_for_ops) $ \(op, tmps) -> do
        tmps_flat <- mapM (flattenArray (length dims') dims_flat) tmps
        virtualisedGroupScan
          (Just crossesSegment)
          (sExt32 $ tvExp dims_flat)
          (segBinOpLambda op)
          tmps_flat

      sOp $ Imp.ErrorSync Imp.FenceLocal

      forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
        copyDWIM
          (patElemName pe)
          []
          (Var arr)
          (map (unitSlice 0) (init dims') ++ [DimFix $ last dims' - 1])

      sOp $ Imp.Barrier Imp.FenceLocal

    nonvirtCase [dim'] tmps_for_ops = do
      -- Nonsegmented case (or rather, a single segment) - this we can
      -- handle directly with a group-level reduction.
      forM_ (zip ops tmps_for_ops) $ \(op, tmps) ->
        groupReduce (sExt32 dim') (segBinOpLambda op) tmps
      sOp $ Imp.ErrorSync Imp.FenceLocal
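      -- groupReduce leaves the result of each reduction in element 0
      -- of the corresponding array, which is what we copy out.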
      forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
        copyDWIMFix (patElemName pe) [] (Var arr) [0]
    --
    nonvirtCase dims' tmps_for_ops = do
      -- Segmented intra-group reductions are turned into (regular)
      -- segmented scans.  It is possible that this can be done
      -- better, but at least this approach is simple.

      -- groupScan operates on flattened arrays.  This does not
      -- involve copying anything; merely playing with the index
      -- function.
      dims_flat <- dPrimV "dims_flat" $ product dims'
      let segment_size = last dims'
          crossesSegment from to =
            (sExt64 to - sExt64 from) .>. (sExt64 to `rem` sExt64 segment_size)

      forM_ (zip ops tmps_for_ops) $ \(op, tmps) -> do
        tmps_flat <- mapM (flattenArray (length dims') dims_flat) tmps
        groupScan
          (Just crossesSegment)
          (product dims')
          (product dims')
          (segBinOpLambda op)
          tmps_flat

      sOp $ Imp.ErrorSync Imp.FenceLocal

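      -- After the inclusive segmented scan, the last element of each
      -- segment holds that segment's reduction result, so that is
      -- the element we copy to the destination.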
      forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
        copyDWIM
          (patElemName pe)
          []
          (Var arr)
          (map (unitSlice 0) (init dims') ++ [DimFix $ last dims' - 1])

      sOp $ Imp.Barrier Imp.FenceLocal
compileGroupOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do
  compileFlatId lvl space
  let (ltids, _dims) = unzip $ unSegSpace space

  -- We don't need the red_pes, because it is guaranteed by our type
  -- rules that they occupy the same memory as the destinations for
  -- the ops.
  let num_red_res = length ops + sum (map (length . histNeutral) ops)
      (_red_pes, map_pes) =
        splitAt num_red_res $ patElems pat

  ops' <- prepareIntraGroupSegHist (segGroupSize lvl) ops

  -- Ensure that all locks have been initialised.
  sOp $ Imp.Barrier Imp.FenceLocal

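  -- Each thread now computes its bin and, when the bin is in bounds,
  -- applies the operator through an atomic update (for operators
  -- without a hardware atomic, this goes through the locking scheme
  -- prepared by prepareIntraGroupSegHist).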
  groupCoverSegSpace (segVirt lvl) space $
    compileStms mempty (kernelBodyStms kbody) $ do
      let (red_res, map_res) = splitAt num_red_res $ kernelBodyResult kbody
          (red_is, red_vs) = splitAt (length ops) $ map kernelResultSubExp red_res
      zipWithM_ (compileThreadResult space) map_pes map_res

      let vs_per_op = chunks (map (length . histDest) ops) red_vs

      forM_ (zip4 red_is vs_per_op ops' ops) $
        \(bin, op_vs, do_op, HistOp dest_shape _ _ _ shape lam) -> do
          let bin' = pe64 bin
              dest_shape' = map pe64 $ shapeDims dest_shape
              bin_in_bounds = inBounds (Slice (map DimFix [bin'])) dest_shape'
              bin_is = map Imp.le64 (init ltids) ++ [bin']
              vs_params = takeLast (length op_vs) $ lambdaParams lam

          sComment "perform atomic updates" $
            sWhen bin_in_bounds $ do
              dLParams $ lambdaParams lam
              sLoopNest shape $ \is -> do
                forM_ (zip vs_params op_vs) $ \(p, v) ->
                  copyDWIMFix (paramName p) [] v is
                do_op (bin_is ++ is)

  sOp $ Imp.ErrorSync Imp.FenceLocal
compileGroupOp pat _ =
  compilerBugS $ "compileGroupOp: cannot compile rhs of binding " ++ prettyString pat

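-- | The operations used when compiling code inside a group-level
-- body: expressions and statements are compiled with group-level
-- semantics, copies go through 'copyInGroup', and allocations are
-- served from local memory.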
groupOperations :: Operations GPUMem KernelEnv Imp.KernelOp
groupOperations =
  (defaultOperations compileGroupOp)
    { opsCopyCompiler = copyInGroup,
      opsExpCompiler = compileGroupExp,
      opsStmsCompiler = \_ -> defCompileStms mempty,
      opsAllocCompilers =
        M.fromList [(Space "local", allocLocal)]
    }

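-- | Is this 'SubExp' a variable naming an array that resides in
-- local memory?  A constant is trivially not such an array.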
arrayInLocalMemory :: SubExp -> InKernelGen Bool
arrayInLocalMemory (Var name) = do
  res <- lookupVar name
  case res of
    ArrayVar _ entry ->
      (Space "local" ==) . entryMemSpace
        <$> lookupMemory (memLocName (entryArrayLoc entry))
    _ -> pure False
arrayInLocalMemory Constant {} = pure False

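-- | Compile a kernel whose body is executed at the group level: a
-- variant of 'sKernel' built on 'groupOperations' and
-- 'kernelGroupId'.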
sKernelGroup ::
  String ->
  VName ->
  KernelAttrs ->
  InKernelGen () ->
  CallKernelGen ()
sKernelGroup = sKernel groupOperations kernelGroupId

compileGroupResult ::
  SegSpace ->
  PatElem LetDecMem ->
  KernelResult ->
  InKernelGen ()
compileGroupResult _ pe (TileReturns _ [(w, per_group_elems)] what) = do
  n <- pe64 . arraySize 0 <$> lookupType what

  constants <- kernelConstants <$> askEnv
  let ltid = sExt64 $ kernelLocalThreadId constants
      offset =
        pe64 per_group_elems
          * sExt64 (kernelGroupId constants)
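  -- Thread ltid of group gid thus starts writing at output element
  -- gid * per_group_elems + ltid.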

  -- Avoid loop for the common case where each thread is statically
  -- known to write at most one element.
  localOps threadOperations $
    if pe64 per_group_elems == kernelGroupSize constants
      then
        sWhen (ltid + offset .<. pe64 w) $
          copyDWIMFix (patElemName pe) [ltid + offset] (Var what) [ltid]
      else sFor "i" (n `divUp` kernelGroupSize constants) $ \i -> do
        j <- dPrimVE "j" $ kernelGroupSize constants * i + ltid
        sWhen (j + offset .<. pe64 w) $
          copyDWIMFix (patElemName pe) [j + offset] (Var what) [j]
compileGroupResult space pe (TileReturns _ dims what) = do
  let gids = map fst $ unSegSpace space
      out_tile_sizes = map (pe64 . snd) dims
      group_is = zipWith (*) (map Imp.le64 gids) out_tile_sizes
  local_is <- localThreadIDs $ map snd dims
  is_for_thread <-
    mapM (dPrimV "thread_out_index") $
      zipWith (+) group_is local_is

  localOps threadOperations $
    sWhen (isActive $ zip (map tvVar is_for_thread) $ map fst dims) $
      copyDWIMFix (patElemName pe) (map tvExp is_for_thread) (Var what) local_is
compileGroupResult space pe (RegTileReturns _ dims_n_tiles what) = do
  constants <- kernelConstants <$> askEnv

  let gids = map fst $ unSegSpace space
      (dims, group_tiles, reg_tiles) = unzip3 dims_n_tiles
      group_tiles' = map pe64 group_tiles
      reg_tiles' = map pe64 reg_tiles

  -- Which group tile is this group responsible for?
  let group_tile_is = map Imp.le64 gids

  -- Within the group tile, which register tile is this thread
  -- responsible for?
  reg_tile_is <-
    dIndexSpace' "reg_tile_i" group_tiles' $ sExt64 $ kernelLocalThreadId constants

  -- Compute output array slice for the register tile belonging to
  -- this thread.
  let regTileSliceDim (group_tile, group_tile_i) (reg_tile, reg_tile_i) = do
        tile_dim_start <-
          dPrimVE "tile_dim_start" $
            reg_tile * (group_tile * group_tile_i + reg_tile_i)
        pure $ DimSlice tile_dim_start reg_tile 1
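  -- Illustrative arithmetic: with a group tile of 8 and a register
  -- tile of 2 in some dimension, the thread with group_tile_i = 3
  -- and reg_tile_i = 5 owns the slice starting at
  -- 2 * (8 * 3 + 5) = 58, covering 2 consecutive elements.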
  reg_tile_slices <-
    Slice
      <$> zipWithM
        regTileSliceDim
        (zip group_tiles' group_tile_is)
        (zip reg_tiles' reg_tile_is)

  localOps threadOperations $
    sLoopNest (Shape reg_tiles) $ \is_in_reg_tile -> do
      let dest_is = fixSlice reg_tile_slices is_in_reg_tile
          src_is = reg_tile_is ++ is_in_reg_tile
      sWhen (foldl1 (.&&.) $ zipWith (.<.) dest_is $ map pe64 dims) $
        copyDWIMFix (patElemName pe) dest_is (Var what) src_is
compileGroupResult space pe (Returns _ _ what) = do
  constants <- kernelConstants <$> askEnv
  in_local_memory <- arrayInLocalMemory what
  let gids = map (Imp.le64 . fst) $ unSegSpace space

  if not in_local_memory
    then
      localOps threadOperations $
        sWhen (kernelLocalThreadId constants .==. 0) $
          copyDWIMFix (patElemName pe) gids what []
    else -- If the result of the group is an array in local memory, we
    -- store it by collective copying among all the threads of the
    -- group.  TODO: also do this if the array is in global memory
    -- (but this is a bit more tricky, synchronisation-wise).
      copyDWIMFix (patElemName pe) gids what []
compileGroupResult _ _ WriteReturns {} =
  compilerLimitationS "compileGroupResult: WriteReturns not handled yet."

-- | The sizes of nested iteration spaces in the kernel.
type SegOpSizes = S.Set [SubExp]

-- | Various useful precomputed information for group-level SegOps.
data Precomputed = Precomputed
  { pcSegOpSizes :: SegOpSizes,
    pcChunkItersMap :: M.Map [SubExp] (Imp.TExp Int32)
  }

-- | Find the sizes of nested parallelism in a t'SegOp' body.
segOpSizes :: Stms GPUMem -> SegOpSizes
segOpSizes = onStms
  where
    onStms = foldMap (onExp . stmExp)
    onExp (Op (Inner (SegOp op))) =
      case segVirt $ segLevel op of
        SegNoVirtFull seq_dims ->
          S.singleton $ map snd $ snd $ partitionSeqDims seq_dims $ segSpace op
        _ -> S.singleton $ map snd $ unSegSpace $ segSpace op
    onExp (BasicOp (Replicate shape _)) =
      S.singleton $ shapeDims shape
    onExp (Match _ cases defbody _) =
      foldMap (onStms . bodyStms . caseBody) cases <> onStms (bodyStms defbody)
    onExp (DoLoop _ _ body) =
      onStms (bodyStms body)
    onExp _ = mempty

-- | Precompute various constants and useful information.
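--
-- For each distinct nested iteration-space shape this includes the
-- number of group-size chunks the space spans; illustratively, a
-- shape of [50][100] with a group size of 256 gives
-- (50 * 100) `divUp` 256 = 20 chunks.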
precomputeConstants :: Count GroupSize (Imp.TExp Int64) -> Stms GPUMem -> CallKernelGen Precomputed
precomputeConstants group_size stms = do
  let sizes = segOpSizes stms
  iters_map <- M.fromList <$> mapM mkMap (S.toList sizes)
  pure $ Precomputed sizes iters_map
  where
    mkMap dims = do
      let n = product $ map Imp.pe64 dims
      num_chunks <- dPrimVE "num_chunks" $ sExt32 $ n `divUp` unCount group_size
      pure (dims, num_chunks)

-- | Make use of various precomputed constants.
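--
-- A sketch of intended use (the names here are illustrative, not
-- from this module):
--
-- > pre <- precomputeConstants group_size kstms
-- > sKernelGroup "kernel" v attrs $
-- >   precomputedConstants pre $
-- >     compileStms mempty kstms $ pure ()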
precomputedConstants :: Precomputed -> InKernelGen a -> InKernelGen a
precomputedConstants pre m = do
  ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  new_ids <- M.fromList <$> mapM (mkMap ltid) (S.toList (pcSegOpSizes pre))
  let f env =
        env
          { kernelConstants =
              (kernelConstants env)
                { kernelLocalIdMap = new_ids,
                  kernelChunkItersMap = pcChunkItersMap pre
                }
          }
  localEnv f m
  where
    mkMap ltid dims = do
      let dims' = map pe64 dims
      ids' <- dIndexSpace' "ltid_pre" dims' (sExt64 ltid)
      pure (dims, map sExt32 ids')