{-# LANGUAGE TypeFamilies #-}

-- | Generation of kernels with group-level bodies.
module Futhark.CodeGen.ImpGen.GPU.Group
  ( sKernelGroup,
    compileGroupResult,
    groupOperations,

    -- * Precomputation
    Precomputed,
    precomputeConstants,
    precomputedConstants,
    atomicUpdateLocking,
  )
where

import Control.Monad
import Data.Bifunctor
import Data.List (partition, zip4)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.Construct (fullSliceNum)
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Transform.Rename
import Futhark.Util (chunks, mapAccumLM, takeLast)
import Futhark.Util.IntegralExp (divUp, rem)
import Prelude hiding (quot, rem)

-- | @flattenArray k flat arr@ flattens the outer @k@ dimensions of
-- @arr@ to @flat@.  (Make sure @flat@ is the product of those
-- dimensions or you'll have a bad time.)
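--
-- For instance (an illustration, not taken from the code below):
-- flattening the outer two dimensions of an array of shape
-- @[n][m][k]@ with @flat@ bound to @n*m@ yields a view of shape
-- @[n*m][k]@; no memory is moved, only the LMAD index function is
-- reshaped.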
flattenArray :: Int -> TV Int64 -> VName -> ImpM rep r op VName
flattenArray k flat arr = do
  ArrayEntry arr_loc pt <- lookupArray arr
  let flat_shape = Shape $ Var (tvVar flat) : drop k (memLocShape arr_loc)
  sArray (baseString arr ++ "_flat") pt flat_shape (memLocName arr_loc) $
    fromMaybe (error "flattenArray") $
      LMAD.reshape (memLocLMAD arr_loc) (map pe64 $ shapeDims flat_shape)

sliceArray :: Imp.TExp Int64 -> TV Int64 -> VName -> ImpM rep r op VName
sliceArray start size arr = do
  MemLoc mem _ ixfun <- entryArrayLoc <$> lookupArray arr
  arr_t <- lookupType arr
  let slice =
        fullSliceNum
          (map Imp.pe64 (arrayDims arr_t))
          [DimSlice start (tvExp size) 1]
  sArray
    (baseString arr ++ "_chunk")
    (elemType arr_t)
    (arrayShape arr_t `setOuterDim` Var (tvVar size))
    mem
    $ LMAD.slice ixfun slice

-- | @applyLambda lam dests args@ emits code that:
--
-- 1. Binds each parameter of @lam@ to the corresponding element of
--    @args@, interpreted as a (name,slice) pair (as in 'copyDWIM').
--    Use an empty list for a scalar.
--
-- 2. Executes the body of @lam@.
--
-- 3. Binds the t'SubExp's that are the 'Result' of @lam@ to the
-- provided @dest@s, again interpreted as the destination for a
-- 'copyDWIM'.
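--
-- A hypothetical use, with invented names for illustration: given a
-- binary lambda @lam@ over scalars and variables @x@, @y@, @acc@,
--
-- @applyLambda lam [(acc, [])] [(Var x, []), (Var y, [])]@
--
-- binds the two parameters to @x@ and @y@, executes the body, and
-- copies the single result to @acc@.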
applyLambda ::
  (Mem rep inner) =>
  Lambda rep ->
  [(VName, [DimIndex (Imp.TExp Int64)])] ->
  [(SubExp, [DimIndex (Imp.TExp Int64)])] ->
  ImpM rep r op ()
applyLambda lam dests args = do
  dLParams $ lambdaParams lam
  forM_ (zip (lambdaParams lam) args) $ \(p, (arg, arg_slice)) ->
    copyDWIM (paramName p) [] arg arg_slice
  compileStms mempty (bodyStms $ lambdaBody lam) $ do
    let res = map resSubExp $ bodyResult $ lambdaBody lam
    forM_ (zip dests res) $ \((dest, dest_slice), se) ->
      copyDWIM dest dest_slice se []

-- | As applyLambda, but first rename the names in the lambda.  This
-- makes it safe to apply it in multiple places.  (It might be safe
-- anyway, but you have to be more careful - use this if you are in
-- doubt.)
applyRenamedLambda ::
  (Mem rep inner) =>
  Lambda rep ->
  [(VName, [DimIndex (Imp.TExp Int64)])] ->
  [(SubExp, [DimIndex (Imp.TExp Int64)])] ->
  ImpM rep r op ()
applyRenamedLambda lam dests args = do
  lam_renamed <- renameLambda lam
  applyLambda lam_renamed dests args

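-- | Run the given action once per chunk of a @w@-iteration space,
-- where each chunk is at most the group size.  The action receives
-- the chunk start offset and the chunk size.  As an illustration:
-- with @w = 1000@ and a group size of 256, this produces four chunks
-- of sizes 256, 256, 256, and 232.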
groupChunkLoop ::
  Imp.TExp Int32 ->
  (Imp.TExp Int32 -> TV Int64 -> InKernelGen ()) ->
  InKernelGen ()
groupChunkLoop w m = do
  constants <- kernelConstants <$> askEnv
  let max_chunk_size = sExt32 $ kernelGroupSize constants
  num_chunks <- dPrimVE "num_chunks" $ w `divUp` max_chunk_size
  sFor "chunk_i" num_chunks $ \chunk_i -> do
    chunk_start <-
      dPrimVE "chunk_start" $ chunk_i * max_chunk_size
    chunk_end <-
      dPrimVE "chunk_end" $ sMin32 w (chunk_start + max_chunk_size)
    chunk_size <-
      dPrimV "chunk_size" $ sExt64 $ chunk_end - chunk_start
    m chunk_start chunk_size

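-- | A group-level scan over @w@ elements, where @w@ may exceed the
-- group size.  The input is processed in group-size chunks
-- ('groupChunkLoop'); before each chunk is scanned, a single thread
-- folds the carry (the last element of the previous chunk) into the
-- chunk's first element, unless a segment boundary intervenes.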
virtualisedGroupScan ::
  Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
  Imp.TExp Int32 ->
  Lambda GPUMem ->
  [VName] ->
  InKernelGen ()
virtualisedGroupScan seg_flag w lam arrs = do
  groupChunkLoop w $ \chunk_start chunk_size -> do
    constants <- kernelConstants <$> askEnv
    let ltid = kernelLocalThreadId constants
        crosses_segment =
          case seg_flag of
            Nothing -> false
            Just flag_true ->
              flag_true (sExt32 (chunk_start - 1)) (sExt32 chunk_start)
    sComment "possibly incorporate carry" $
      sWhen (chunk_start .>. 0 .&&. ltid .==. 0 .&&. bNot crosses_segment) $ do
        carry_idx <- dPrimVE "carry_idx" $ sExt64 chunk_start - 1
        applyRenamedLambda
          lam
          (map (,[DimFix $ sExt64 chunk_start]) arrs)
          ( map ((,[DimFix carry_idx]) . Var) arrs
              ++ map ((,[DimFix $ sExt64 chunk_start]) . Var) arrs
          )

    arrs_chunks <- mapM (sliceArray (sExt64 chunk_start) chunk_size) arrs

    sOp $ Imp.ErrorSync Imp.FenceLocal

    groupScan seg_flag (sExt64 w) (tvExp chunk_size) lam arrs_chunks

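-- A copy at group level.  When both sides live in 'ScalarSpace'
-- (i.e. thread-private register arrays), every thread performs the
-- copy itself; otherwise the elements are divided among the threads
-- of the group and a local barrier is issued afterwards.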
copyInGroup :: CopyCompiler GPUMem KernelEnv Imp.KernelOp
copyInGroup pt destloc srcloc = do
  dest_space <- entryMemSpace <$> lookupMemory (memLocName destloc)
  src_space <- entryMemSpace <$> lookupMemory (memLocName srcloc)

  let src_lmad = memLocLMAD srcloc
      dims = LMAD.shape src_lmad
      rank = length dims

  case (dest_space, src_space) of
    (ScalarSpace destds _, ScalarSpace srcds _) -> do
      let fullDim d = DimSlice 0 d 1
          destslice' =
            Slice $
              replicate (rank - length destds) (DimFix 0)
                ++ takeLast (length destds) (map fullDim dims)
          srcslice' =
            Slice $
              replicate (rank - length srcds) (DimFix 0)
                ++ takeLast (length srcds) (map fullDim dims)
      lmadCopy
        pt
        (sliceMemLoc destloc destslice')
        (sliceMemLoc srcloc srcslice')
    _ -> do
      groupCoverSpace (map sExt32 dims) $ \is ->
        lmadCopy
          pt
          (sliceMemLoc destloc (Slice $ map (DimFix . sExt64) is))
          (sliceMemLoc srcloc (Slice $ map (DimFix . sExt64) is))
      sOp $ Imp.Barrier Imp.FenceLocal

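-- Compute the per-dimension local thread indices for the given
-- dimensions, preferring a precomputed mapping from the kernel
-- constants and otherwise unflattening the local thread ID on the
-- fly.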
localThreadIDs :: [SubExp] -> InKernelGen [Imp.TExp Int64]
localThreadIDs dims = do
  ltid <- sExt64 . kernelLocalThreadId . kernelConstants <$> askEnv
  let dims' = map pe64 dims
  maybe (dIndexSpace' "ltid" dims' ltid) (pure . map sExt64)
    . M.lookup dims
    . kernelLocalIdMap
    . kernelConstants
    =<< askEnv

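-- Split the dimensions of a 'SegSpace' into those that must be
-- executed sequentially and those executed in parallel, as dictated
-- by the given 'SegSeqDims'.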
partitionSeqDims :: SegSeqDims -> SegSpace -> ([(VName, SubExp)], [(VName, SubExp)])
partitionSeqDims (SegSeqDims seq_is) space =
  bimap (map fst) (map fst) $
    partition ((`elem` seq_is) . snd) (zip (unSegSpace space) [0 ..])

compileFlatId :: SegSpace -> InKernelGen ()
compileFlatId space = do
  ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  dPrimV_ (segFlat space) ltid

-- Construct the necessary lock arrays for an intra-group histogram.
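-- Because the number of locks equals the group size, distinct
-- buckets may hash to the same lock (bucket @j@ uses lock
-- @flattenIndex dims j `rem` num_locks@); this affects only
-- contention, not correctness.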
prepareIntraGroupSegHist ::
  Shape ->
  Count GroupSize SubExp ->
  [HistOp GPUMem] ->
  InKernelGen [[Imp.TExp Int64] -> InKernelGen ()]
prepareIntraGroupSegHist segments group_size =
  fmap snd . mapAccumLM onOp Nothing
  where
    onOp l op = do
      constants <- kernelConstants <$> askEnv
      atomicBinOp <- kernelAtomics <$> askEnv

      let local_subhistos = histDest op

      case (l, atomicUpdateLocking atomicBinOp $ histOp op) of
        (_, AtomicPrim f) -> pure (l, f (Space "local") local_subhistos)
        (_, AtomicCAS f) -> pure (l, f (Space "local") local_subhistos)
        (Just l', AtomicLocking f) -> pure (l, f l' (Space "local") local_subhistos)
        (Nothing, AtomicLocking f) -> do
          locks <- newVName "locks"

          let num_locks = pe64 $ unCount group_size
              dims = map pe64 $ shapeDims (segments <> histOpShape op <> histShape op)
              l' = Locking locks 0 1 0 (pure . (`rem` num_locks) . flattenIndex dims)
              locks_t = Array int32 (Shape [unCount group_size]) NoUniqueness

          locks_mem <- sAlloc "locks_mem" (typeSize locks_t) $ Space "local"
          dArray locks int32 (arrayShape locks_t) locks_mem $
            LMAD.iota 0 . map pe64 . arrayDims $
              locks_t

          sComment "All locks start out unlocked" $
            groupCoverSpace [kernelGroupSize constants] $ \is ->
              copyDWIMFix locks is (intConst Int32 0) []

          pure (Just l', f l' (Space "local") local_subhistos)

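-- Cover a 'SegSpace' with the threads of the group according to the
-- requested virtualisation: 'SegVirt' loops the threads over a space
-- that may be larger than the group, 'SegNoVirt' masks off threads
-- outside the space, and 'SegNoVirtFull' assumes the parallel
-- dimensions exactly match the group size (with any sequential
-- dimensions handled by a loop nest).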
groupCoverSegSpace :: SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace virt space m = do
  let (ltids, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims

  constants <- kernelConstants <$> askEnv
  let group_size = kernelGroupSize constants
  -- Maybe we can statically detect that this is actually a
  -- SegNoVirtFull and generate ever-so-slightly simpler code.
  let virt' = if dims' == [group_size] then SegNoVirtFull (SegSeqDims []) else virt
  case virt' of
    SegVirt -> do
      iters <- M.lookup dims . kernelChunkItersMap . kernelConstants <$> askEnv
      case iters of
        Nothing -> do
          iterations <- dPrimVE "iterations" $ product $ map sExt32 dims'
          groupLoop iterations $ \i -> do
            dIndexSpace (zip ltids dims') $ sExt64 i
            m
        Just num_chunks -> localOps threadOperations $ do
          let ltid = kernelLocalThreadId constants
          sFor "chunk_i" num_chunks $ \chunk_i -> do
            i <- dPrimVE "i" $ chunk_i * sExt32 group_size + ltid
            dIndexSpace (zip ltids dims') $ sExt64 i
            sWhen (inBounds (Slice (map (DimFix . le64) ltids)) dims') m
    SegNoVirt -> localOps threadOperations $ do
      zipWithM_ dPrimV_ ltids =<< localThreadIDs dims
      sWhen (isActive $ zip ltids dims) m
    SegNoVirtFull seq_dims -> do
      let ((ltids_seq, dims_seq), (ltids_par, dims_par)) =
            bimap unzip unzip $ partitionSeqDims seq_dims space
      sLoopNest (Shape dims_seq) $ \is_seq -> do
        zipWithM_ dPrimV_ ltids_seq is_seq
        localOps threadOperations $ do
          zipWithM_ dPrimV_ ltids_par =<< localThreadIDs dims_par
          m

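-- Compile an expression in the group context.  Most expressions fall
-- through to 'defCompileExp'; the cases below handle constructs that
-- must be executed by a single thread, or that can be parallelised
-- across the whole group.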
compileGroupExp :: ExpCompiler GPUMem KernelEnv Imp.KernelOp
compileGroupExp (Pat [pe]) (BasicOp (Opaque _ se)) =
  -- Cannot print in GPU code.
  copyDWIM (patElemName pe) [] se []
-- The static arrays stuff does not work inside kernels.
compileGroupExp (Pat [dest]) (BasicOp (ArrayLit es _)) =
  forM_ (zip [0 ..] es) $ \(i, e) ->
    copyDWIMFix (patElemName dest) [fromIntegral (i :: Int64)] e []
compileGroupExp _ (BasicOp (UpdateAcc acc is vs)) = do
  ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  sWhen (ltid .==. 0) $ updateAcc acc is vs
  sOp $ Imp.Barrier Imp.FenceLocal
compileGroupExp (Pat [dest]) (BasicOp (Replicate ds se)) | ds /= mempty = do
  flat <- newVName "rep_flat"
  is <- replicateM (arrayRank dest_t) (newVName "rep_i")
  let is' = map le64 is
  groupCoverSegSpace SegVirt (SegSpace flat $ zip is $ arrayDims dest_t) $
    copyDWIMFix (patElemName dest) is' se (drop (shapeRank ds) is')
  sOp $ Imp.Barrier Imp.FenceLocal
  where
    dest_t = patElemType dest
compileGroupExp (Pat [dest]) (BasicOp (Iota n e s it)) = do
  n' <- toExp n
  e' <- toExp e
  s' <- toExp s
  groupLoop (TPrimExp n') $ \i' -> do
    x <-
      dPrimV "x" $
        TPrimExp $
          BinOpExp (Add it OverflowUndef) e' $
            BinOpExp (Mul it OverflowUndef) (untyped i') s'
    copyDWIMFix (patElemName dest) [i'] (Var (tvVar x)) []
  sOp $ Imp.Barrier Imp.FenceLocal

-- When generating code for a scalar in-place update, we must make
-- sure that only one thread performs the write.  When writing an
-- array, the group-level copy code will take care of doing the right
-- thing.
compileGroupExp (Pat [pe]) (BasicOp (Update safety _ slice se))
  | null $ sliceDims slice = do
      sOp $ Imp.Barrier Imp.FenceLocal
      ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
      sWhen (ltid .==. 0) $
        case safety of
          Unsafe -> write
          Safe -> sWhen (inBounds slice' dims) write
      sOp $ Imp.Barrier Imp.FenceLocal
  where
    slice' = fmap pe64 slice
    dims = map pe64 $ arrayDims $ patElemType pe
    write = copyDWIM (patElemName pe) (unSlice slice') se []
compileGroupExp dest e = do
  -- It is messy to jump into control flow for error handling.
  -- Avoid that by always doing an error sync here.  Potential
  -- improvement: only do this if any errors are pending (this could
  -- also be handled in later codegen).
  when (doSync e) $ sOp $ Imp.ErrorSync Imp.FenceLocal
  defCompileExp dest e
  where
    doSync Loop {} = True
    doSync Match {} = True
    doSync _ = False

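-- Compile a group-level operation, including nested 'SegOp's, which
-- are executed by the threads of the group with the appropriate
-- barriers and error syncs.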
compileGroupOp :: OpCompiler GPUMem KernelEnv Imp.KernelOp
compileGroupOp :: OpCompiler GPUMem KernelEnv KernelOp
compileGroupOp Pat (LetDec GPUMem)
pat (Alloc SubExp
size Space
space) =
  Pat LetDecMem -> SubExp -> Space -> InKernelGen ()
kernelAlloc Pat (LetDec GPUMem)
pat SubExp
size Space
space
compileGroupOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegMap SegLevel
lvl SegSpace
space [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody GPUMem
body))) = do
  SegSpace -> InKernelGen ()
compileFlatId SegSpace
space

  SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace (SegLevel -> SegVirt
segVirt SegLevel
lvl) SegSpace
space forall a b. (a -> b) -> a -> b
$
    forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
body) forall a b. (a -> b) -> a -> b
$
      forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (SegSpace -> PatElem LetDecMem -> KernelResult -> InKernelGen ()
compileThreadResult SegSpace
space) (forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec GPUMem)
pat) forall a b. (a -> b) -> a -> b
$
        forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
body
  forall op rep r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
compileGroupOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegScan SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
scans [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody GPUMem
body))) = do
  SegSpace -> InKernelGen ()
compileFlatId SegSpace
space

  let ([VName]
ltids, [SubExp]
dims) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      dims' :: [TPrimExp Int64 VName]
dims' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims

  SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace (SegLevel -> SegVirt
segVirt SegLevel
lvl) SegSpace
space forall a b. (a -> b) -> a -> b
$
    forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
body) forall a b. (a -> b) -> a -> b
$
      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip (forall dec. Pat dec -> [VName]
patNames Pat (LetDec GPUMem)
pat) forall a b. (a -> b) -> a -> b
$ forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
body) forall a b. (a -> b) -> a -> b
$ \(VName
dest, KernelResult
res) ->
        forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
          VName
dest
          (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
ltids)
          (KernelResult -> SubExp
kernelResultSubExp KernelResult
res)
          []

  Fence
fence <- [VName] -> InKernelGen Fence
fenceForArrays forall a b. (a -> b) -> a -> b
$ forall dec. Pat dec -> [VName]
patNames Pat (LetDec GPUMem)
pat
  forall op rep r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
fence

  let segment_size :: TPrimExp Int64 VName
segment_size = forall a. [a] -> a
last [TPrimExp Int64 VName]
dims'
      crossesSegment :: TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment TExp Int32
from TExp Int32
to =
        (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to forall a. Num a => a -> a -> a
- forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
from) forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName
segment_size)

  -- groupScan needs to treat the scan output as a one-dimensional
  -- array of scan elements, so we invent some new flattened arrays
  -- here.
  TV Int64
dims_flat <- forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"dims_flat" forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims'
  let scan :: SegBinOp GPUMem
scan = forall a. [a] -> a
head [SegBinOp GPUMem]
scans
      num_scan_results :: Int
num_scan_results = forall (t :: * -> *) a. Foldable t => t a -> Int
length forall a b. (a -> b) -> a -> b
$ forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan
  [VName]
arrs_flat <-
    forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall rep r op. Int -> TV Int64 -> VName -> ImpM rep r op VName
flattenArray (forall (t :: * -> *) a. Foldable t => t a -> Int
length [TPrimExp Int64 VName]
dims') TV Int64
dims_flat) forall a b. (a -> b) -> a -> b
$
      forall a. Int -> [a] -> [a]
take Int
num_scan_results forall a b. (a -> b) -> a -> b
$
        forall dec. Pat dec -> [VName]
patNames Pat (LetDec GPUMem)
pat

  case SegLevel -> SegVirt
segVirt SegLevel
lvl of
    SegVirt
SegVirt ->
      Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Int32 -> Lambda GPUMem -> [VName] -> InKernelGen ()
virtualisedGroupScan
        (forall a. a -> Maybe a
Just TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment)
        (forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dims_flat)
        (forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan)
        [VName]
arrs_flat
    SegVirt
_ ->
      Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> Lambda GPUMem
-> [VName]
-> InKernelGen ()
groupScan
        (forall a. a -> Maybe a
Just TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment)
        (forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims')
        (forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims')
        (forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan)
        [VName]
arrs_flat
compileGroupOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegRed SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
ops [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody GPUMem
body))) = do
  SegSpace -> InKernelGen ()
compileFlatId SegSpace
space

  let dims' :: [TPrimExp Int64 VName]
dims' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
      mkTempArr :: TypeBase (ShapeBase SubExp) NoUniqueness
-> ImpM GPUMem KernelEnv KernelOp VName
mkTempArr TypeBase (ShapeBase SubExp) NoUniqueness
t =
        forall rep r op.
[Char]
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray [Char]
"red_arr" (forall shape u. TypeBase shape u -> PrimType
elemType TypeBase (ShapeBase SubExp) NoUniqueness
t) (forall d. [d] -> ShapeBase d
Shape [SubExp]
dims forall a. Semigroup a => a -> a -> a
<> forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape TypeBase (ShapeBase SubExp) NoUniqueness
t) forall a b. (a -> b) -> a -> b
$ [Char] -> Space
Space [Char]
"local"

  [VName]
tmp_arrs <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM TypeBase (ShapeBase SubExp) NoUniqueness
-> ImpM GPUMem KernelEnv KernelOp VName
mkTempArr forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (forall rep.
Lambda rep -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda) [SegBinOp GPUMem]
ops
  SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace (SegLevel -> SegVirt
segVirt SegLevel
lvl) SegSpace
space forall a b. (a -> b) -> a -> b
$
    forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
body) forall a b. (a -> b) -> a -> b
$ do
      let ([KernelResult]
red_res, [KernelResult]
map_res) =
            forall a. Int -> [a] -> ([a], [a])
splitAt (forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem]
ops) forall a b. (a -> b) -> a -> b
$ forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
body
      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
tmp_arrs [KernelResult]
red_res) forall a b. (a -> b) -> a -> b
$ \(VName
dest, KernelResult
res) ->
        forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
dest (forall a b. (a -> b) -> [a] -> [b]
map forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
ltids) (KernelResult -> SubExp
kernelResultSubExp KernelResult
res) []
      forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (SegSpace -> PatElem LetDecMem -> KernelResult -> InKernelGen ()
compileThreadResult SegSpace
space) [PatElem LetDecMem]
map_pes [KernelResult]
map_res

  forall op rep r. op -> ImpM rep r op ()
sOp forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal

  let tmps_for_ops :: [[VName]]
tmps_for_ops = forall a. [Int] -> [a] -> [[a]]
chunks (forall a b. (a -> b) -> [a] -> [b]
map (forall (t :: * -> *) a. Foldable t => t a -> Int
length forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral) [SegBinOp GPUMem]
ops) [VName]
tmp_arrs
  case SegLevel -> SegVirt
segVirt SegLevel
lvl of
    SegVirt
SegVirt -> [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
virtCase [TPrimExp Int64 VName]
dims' [[VName]]
tmps_for_ops
    SegVirt
_ -> [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
nonvirtCase [TPrimExp Int64 VName]
dims' [[VName]]
tmps_for_ops
  where
    ([VName]
ltids, [SubExp]
dims) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
    ([PatElem LetDecMem]
red_pes, [PatElem LetDecMem]
map_pes) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem]
ops) forall a b. (a -> b) -> a -> b
$ forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec GPUMem)
pat

    virtCase :: [Imp.TExp Int64] -> [[VName]] -> InKernelGen ()
    virtCase [dim'] tmps_for_ops = do
      ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
      groupChunkLoop (sExt32 dim') $ \chunk_start chunk_size -> do
        sComment "possibly incorporate carry" $
          sWhen (chunk_start .>. 0 .&&. ltid .==. 0) $
            forM_ (zip ops tmps_for_ops) $ \(op, tmps) ->
              applyRenamedLambda
                (segBinOpLambda op)
                (map (,[DimFix $ sExt64 chunk_start]) tmps)
                ( map ((,[]) . Var . patElemName) red_pes
                    ++ map ((,[DimFix $ sExt64 chunk_start]) . Var) tmps
                )

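        -- Imp.ErrorSync doubles as a group-wide barrier and a check of
        -- the error flag, so no thread starts the reduction below
        -- before the carry has been incorporated.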
        sOp $ Imp.ErrorSync Imp.FenceLocal

        forM_ (zip ops tmps_for_ops) $ \(op, tmps) -> do
          tmps_chunks <- mapM (sliceArray (sExt64 chunk_start) chunk_size) tmps
          groupReduce (sExt32 (tvExp chunk_size)) (segBinOpLambda op) tmps_chunks

        sOp $ Imp.ErrorSync Imp.FenceLocal

        sComment "Save result of reduction." $
          forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
            copyDWIMFix (patElemName pe) [] (Var arr) [sExt64 chunk_start]

    -- The general multi-dimensional virtualised case: implement the
    -- segmented reduction as a virtualised segmented scan over a
    -- flattened array, then read off the last element of each segment.
    virtCase dims' tmps_for_ops = do
      dims_flat <- dPrimV "dims_flat" $ product dims'
      let segment_size = last dims'
          crossesSegment from to =
            (sExt64 to - sExt64 from) .>. (sExt64 to `rem` sExt64 segment_size)

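      -- For example, with segment_size = 4, going from flat index 3 to
      -- 5 gives (5 - 3) = 2 > (5 `rem` 4) = 1, so the scan carry is cut
      -- at the segment boundary; going from 4 to 5 gives 1 > 1 = False,
      -- and the scan proceeds within the segment.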
      forM_ (zip ops tmps_for_ops) $ \(op, tmps) -> do
        tmps_flat <- mapM (flattenArray (length dims') dims_flat) tmps
        virtualisedGroupScan
          (Just crossesSegment)
          (sExt32 $ tvExp dims_flat)
          (segBinOpLambda op)
          tmps_flat

      sOp $ Imp.ErrorSync Imp.FenceLocal

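      -- After the segmented scan, the last element of each segment
      -- (at index last dims' - 1) holds that segment's reduction
      -- result.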
      sComment "Save result of reduction." $
        forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
          copyDWIM
            (patElemName pe)
            []
            (Var arr)
            (map (unitSlice 0) (init dims') ++ [DimFix $ last dims' - 1])

      sOp $ Imp.Barrier Imp.FenceLocal

    -- Nonsegmented case (or rather, a single segment) - this we can
    -- handle directly with a group-level reduction.
    nonvirtCase :: [Imp.TExp Int64] -> [[VName]] -> InKernelGen ()
    nonvirtCase [dim'] tmps_for_ops = do
      forM_ (zip ops tmps_for_ops) $ \(op, tmps) ->
        groupReduce (sExt32 dim') (segBinOpLambda op) tmps
      sOp $ Imp.ErrorSync Imp.FenceLocal
      sComment "Save result of reduction." $
        forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
          copyDWIMFix (patElemName pe) [] (Var arr) [0]
      sOp $ Imp.ErrorSync Imp.FenceLocal

    -- Segmented intra-group reductions are turned into (regular)
    -- segmented scans.  It is possible that this can be done
    -- better, but at least this approach is simple.
    nonvirtCase dims' tmps_for_ops = do
      -- groupScan operates on flattened arrays.  This does not
      -- involve copying anything; merely playing with the index
      -- function.
      dims_flat <- dPrimV "dims_flat" $ product dims'
      let segment_size = last dims'
          crossesSegment from to =
            (sExt64 to - sExt64 from) .>. (sExt64 to `rem` sExt64 segment_size)

      forM_ (zip ops tmps_for_ops) $ \(op, tmps) -> do
        tmps_flat <- mapM (flattenArray (length dims') dims_flat) tmps
        groupScan
          (Just crossesSegment)
          (product dims')
          (product dims')
          (segBinOpLambda op)
          tmps_flat

      sOp $ Imp.ErrorSync Imp.FenceLocal

      sComment "Save result of reduction." $
        forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
          copyDWIM
            (patElemName pe)
            []
            (Var arr)
            (map (unitSlice 0) (init dims') ++ [DimFix $ last dims' - 1])

      sOp $ Imp.Barrier Imp.FenceLocal
compileGroupOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do
  compileFlatId space
  let (ltids, dims) = unzip $ unSegSpace space

  -- We don't need the red_pes, because it is guaranteed by our type
  -- rules that they occupy the same memory as the destinations for
  -- the ops.
  let num_red_res = length ops + sum (map (length . histNeutral) ops)
      (_red_pes, map_pes) =
        splitAt num_red_res $ patElems pat

  group_size <- kernelGroupSizeCount . kernelConstants <$> askEnv
  ops' <- prepareIntraGroupSegHist (Shape $ init dims) group_size ops

  -- Ensure that all locks have been initialised.
  sOp $ Imp.Barrier Imp.FenceLocal

  groupCoverSegSpace (segVirt lvl) space $
    compileStms mempty (kernelBodyStms kbody) $ do
      let (red_res, map_res) = splitAt num_red_res $ kernelBodyResult kbody
          (red_is, red_vs) = splitAt (length ops) $ map kernelResultSubExp red_res
      zipWithM_ (compileThreadResult space) map_pes map_res

      let vs_per_op = chunks (map (length . histDest) ops) red_vs

      forM_ (zip4 red_is vs_per_op ops' ops) $
        \(bin, op_vs, do_op, HistOp dest_shape _ _ _ shape lam) -> do
          let bin' = pe64 bin
              dest_shape' = map pe64 $ shapeDims dest_shape
              bin_in_bounds = inBounds (Slice [DimFix bin']) dest_shape'
              bin_is = map Imp.le64 (init ltids) ++ [bin']
              vs_params = takeLast (length op_vs) $ lambdaParams lam

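          -- Updates whose bin is out of bounds are silently dropped,
          -- matching the language semantics of histogram operations.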
          sComment "perform atomic updates" $
            sWhen bin_in_bounds $ do
              dLParams $ lambdaParams lam
              sLoopNest shape $ \is -> do
                forM_ (zip vs_params op_vs) $ \(p, v) ->
                  copyDWIMFix (paramName p) [] v is
                do_op (bin_is ++ is)

  sOp $ Imp.ErrorSync Imp.FenceLocal
compileGroupOp pat _ =
  compilerBugS $ "compileGroupOp: cannot compile rhs of binding " ++ prettyString pat

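-- | The operations to use when compiling code at group level.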
groupOperations :: Operations GPUMem KernelEnv Imp.KernelOp
groupOperations =
  (defaultOperations compileGroupOp)
    { opsCopyCompiler = copyInGroup,
      opsExpCompiler = compileGroupExp,
      opsStmsCompiler = \_ -> defCompileStms mempty,
      opsAllocCompilers =
        M.fromList [(Space "local", allocLocal)]
    }

arrayInLocalMemory :: SubExp -> InKernelGen Bool
arrayInLocalMemory (Var name) = do
  res <- lookupVar name
  case res of
    ArrayVar _ entry ->
      (Space "local" ==) . entryMemSpace
        <$> lookupMemory (memLocName (entryArrayLoc entry))
    _ -> pure False
arrayInLocalMemory Constant {} = pure False

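-- | Generate a kernel with a group-level body, using
-- 'groupOperations'.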
sKernelGroup ::
  String ->
  VName ->
  KernelAttrs ->
  InKernelGen () ->
  CallKernelGen ()
sKernelGroup = sKernel groupOperations kernelGroupId

compileGroupResult ::
  SegSpace ->
  PatElem LetDecMem ->
  KernelResult ->
  InKernelGen ()
compileGroupResult _ pe (TileReturns _ [(w, per_group_elems)] what) = do
  n <- pe64 . arraySize 0 <$> lookupType what

  constants <- kernelConstants <$> askEnv
  let ltid = sExt64 $ kernelLocalThreadId constants
      offset =
        pe64 per_group_elems
          * sExt64 (kernelGroupId constants)

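  -- Each group writes its tile at offset group_id * per_group_elems
  -- into the output; ltid indexes within the tile.
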
  -- Avoid loop for the common case where each thread is statically
  -- known to write at most one element.
  localOps threadOperations $
    if pe64 per_group_elems == kernelGroupSize constants
      then
        sWhen (ltid + offset .<. pe64 w) $
          copyDWIMFix (patElemName pe) [ltid + offset] (Var what) [ltid]
      else sFor "i" (n `divUp` kernelGroupSize constants) $ \i -> do
        j <- dPrimVE "j" $ kernelGroupSize constants * i + ltid
        sWhen (j + offset .<. pe64 w) $
          copyDWIMFix (patElemName pe) [j + offset] (Var what) [j]
compileGroupResult space pe (TileReturns _ dims what) = do
  let gids = map fst $ unSegSpace space
      out_tile_sizes = map (pe64 . snd) dims
      group_is = zipWith (*) (map Imp.le64 gids) out_tile_sizes
  local_is <- localThreadIDs $ map snd dims
  is_for_thread <-
    mapM (dPrimV "thread_out_index") $
      zipWith (+) group_is local_is

  localOps threadOperations $
    sWhen (isActive $ zip (map tvVar is_for_thread) $ map fst dims) $
      copyDWIMFix (patElemName pe) (map tvExp is_for_thread) (Var what) local_is
compileGroupResult space pe (RegTileReturns _ dims_n_tiles what) = do
  constants <- kernelConstants <$> askEnv

  let gids = map fst $ unSegSpace space
      (dims, group_tiles, reg_tiles) = unzip3 dims_n_tiles
      group_tiles' = map pe64 group_tiles
      reg_tiles' = map pe64 reg_tiles

  -- Which group tile is this group responsible for?
  let group_tile_is = map Imp.le64 gids

  -- Within the group tile, which register tile is this thread
  -- responsible for?
  reg_tile_is <-
    dIndexSpace' "reg_tile_i" group_tiles' $ sExt64 $ kernelLocalThreadId constants

  -- Compute output array slice for the register tile belonging to
  -- this thread.
  let regTileSliceDim (group_tile, group_tile_i) (reg_tile, reg_tile_i) = do
        tile_dim_start <-
          dPrimVE "tile_dim_start" $
            reg_tile * (group_tile * group_tile_i + reg_tile_i)
        pure $ DimSlice tile_dim_start reg_tile 1
  reg_tile_slices <-
    Slice
      <$> zipWithM
        regTileSliceDim
        (zip group_tiles' group_tile_is)
        (zip reg_tiles' reg_tile_is)

  localOps threadOperations $
    sLoopNest (Shape reg_tiles) $ \is_in_reg_tile -> do
      let dest_is = fixSlice reg_tile_slices is_in_reg_tile
          src_is = reg_tile_is ++ is_in_reg_tile
      sWhen (foldl1 (.&&.) $ zipWith (.<.) dest_is $ map pe64 dims) $
        copyDWIMFix (patElemName pe) dest_is (Var what) src_is
compileGroupResult space pe (Returns _ _ what) = do
  constants <- kernelConstants <$> askEnv
  in_local_memory <- arrayInLocalMemory what
  let gids = map (Imp.le64 . fst) $ unSegSpace space

  if not in_local_memory
    then
      localOps threadOperations $
        sWhen (kernelLocalThreadId constants .==. 0) $
          copyDWIMFix (patElemName pe) gids what []
    else -- If the result of the group is an array in local memory, we
    -- store it by collective copying among all the threads of the
    -- group.  TODO: also do this if the array is in global memory
    -- (but this is a bit more tricky, synchronisation-wise).
      copyDWIMFix (patElemName pe) gids what []
compileGroupResult _ _ WriteReturns {} =
  compilerLimitationS "compileGroupResult: WriteReturns not handled yet."

-- | The sizes of nested iteration spaces in the kernel.
type SegOpSizes = S.Set [SubExp]

-- | Various useful precomputed information for group-level SegOps.
data Precomputed = Precomputed
  { pcSegOpSizes :: SegOpSizes,
    pcChunkItersMap :: M.Map [SubExp] (Imp.TExp Int32)
  }

-- | Find the sizes of nested parallelism in a t'SegOp' body.
segOpSizes :: Stms GPUMem -> SegOpSizes
segOpSizes = onStms
  where
    onStms = foldMap onStm
    onStm (Let _ _ (Op (Inner (SegOp op)))) =
      case segVirt $ segLevel op of
        SegNoVirtFull seq_dims ->
          S.singleton $ map snd $ snd $ partitionSeqDims seq_dims $ segSpace op
        _ -> S.singleton $ map snd $ unSegSpace $ segSpace op
    onStm (Let (Pat [pe]) _ (BasicOp (Replicate {}))) =
      S.singleton $ arrayDims $ patElemType pe
    onStm (Let (Pat [pe]) _ (BasicOp (Iota {}))) =
      S.singleton $ arrayDims $ patElemType pe
    onStm (Let (Pat [pe]) _ (BasicOp (Manifest {}))) =
      S.singleton $ arrayDims $ patElemType pe
    onStm (Let _ _ (Match _ cases defbody _)) =
      foldMap (onStms . bodyStms . caseBody) cases <> onStms (bodyStms defbody)
    onStm (Let _ _ (Loop _ _ body)) =
      onStms (bodyStms body)
    onStm _ = mempty

-- | Precompute various constants and useful information.
precomputeConstants :: Count GroupSize (Imp.TExp Int64) -> Stms GPUMem -> CallKernelGen Precomputed
precomputeConstants group_size stms = do
  let sizes = segOpSizes stms
  iters_map <- M.fromList <$> mapM mkMap (S.toList sizes)
  pure $ Precomputed sizes iters_map
  where
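    -- For each distinct iteration-space size, precompute how many
    -- group-size chunks a virtualised traversal of it needs; e.g. a
    -- flat size of 1000 with a group size of 256 gives num_chunks = 4.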
    mkMap dims = do
      let n = product $ map Imp.pe64 dims
      num_chunks <- dPrimVE "num_chunks" $ sExt32 $ n `divUp` unCount group_size
      pure (dims, num_chunks)

-- | Make use of various precomputed constants.
precomputedConstants :: Precomputed -> InKernelGen a -> InKernelGen a
precomputedConstants pre m = do
  ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  new_ids <- M.fromList <$> mapM (mkMap ltid) (S.toList (pcSegOpSizes pre))
  let f env =
        env
          { kernelConstants =
              (kernelConstants env)
                { kernelLocalIdMap = new_ids,
                  kernelChunkItersMap = pcChunkItersMap pre
                }
          }
  localEnv f m
  where
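    -- Unflatten the local thread ID into a multi-dimensional index
    -- for every iteration-space size in the kernel, so group-level
    -- SegOps can look the indexes up instead of recomputing them.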
    mkMap ltid dims = do
      let dims' = map pe64 dims
      ids' <- dIndexSpace' "ltid_pre" dims' (sExt64 ltid)
      pure (dims, map sExt32 ids')