-- | Carefully optimised implementations of GPU transpositions.
-- Written in ImpCode so we can compile it to both CUDA and OpenCL.
module Futhark.CodeGen.ImpGen.Kernels.Transpose
  ( TransposeType(..)
  , TransposeArgs
  , mapTransposeKernel
  )
  where

import Prelude hiding (quot, rem)

import Futhark.CodeGen.ImpCode.Kernels
import Futhark.IR.Prop.Types
import Futhark.Util.IntegralExp (IntegralExp, divUp, quot, rem)

-- | Which form of transposition to generate code for.
--
-- The constructor order is significant for the derived 'Ord' instance
-- and must not be changed.
data TransposeType
  = -- | The general tiled transposition.
    TransposeNormal
  | -- | Variant for input arrays with low width; its indexing is
    -- scaled by the @muly@ component of 'TransposeArgs'.
    TransposeLowWidth
  | -- | Variant for input arrays with low height; its indexing is
    -- scaled by the @mulx@ component of 'TransposeArgs'.
    TransposeLowHeight
  | -- | For small arrays that do not benefit from coalescing.
    TransposeSmall
  deriving (Eq, Ord, Show)

-- | The types of the arguments accepted by a transposition function,
-- bundled as one tuple.  'mapTranspose' destructures the components,
-- in order, under the names given in the comments below.
type TransposeArgs =
  ( VName -- odata: the array that is written to
  , Exp   -- basic_odata_offset: divided by the element byte size before use
  , VName -- idata: the array that is read from
  , Exp   -- basic_idata_offset: divided by the element byte size before use
  , Exp   -- width
  , Exp   -- height
  , Exp   -- input_size: upper bound checked against input indices
  , Exp   -- output_size: upper bound checked against output indices
  , Exp   -- mulx: scaling factor used by the low-height variant
  , Exp   -- muly: scaling factor used by the low-width variant
  , Exp   -- number of arrays (ignored by 'mapTranspose')
  , VName -- block: the "local"-space scratch memory used as tile
  )

-- | How many tile elements each thread handles in the
-- 'TransposeNormal' kernel.  It is used both as the trip count of the
-- per-thread loop and as the divisor of the tile dimension when
-- computing the per-iteration row step.
elemsPerThread :: IntegralExp a => a
elemsPerThread = 4

-- | Generate the body of a transposition kernel.
--
-- Arguments, in order: the block dimension, the 'TransposeArgs' tuple,
-- the element type, and the kind of transposition to generate.
--
-- * 'TransposeSmall' copies straight from the input to the output in
--   "global" space, one element per global thread id, guarded by
--   @input_size@.  It derives the array offset from the global id.
--
-- * The remaining kinds are wrapped by @mkTranspose@, which derives
--   the array offset from group id 2, and stage elements through the
--   "local"-space @block@ memory with a local barrier between the
--   load phase and the store phase.  Rows of @block@ are addressed
--   with a stride of one more than the tile side (the padding
--   described in the documentation of 'mapTransposeKernel').
mapTranspose :: Exp -> TransposeArgs -> PrimType -> TransposeType -> KernelCode
mapTranspose block_dim args t kind =
  case kind of
    TransposeSmall ->
      mconcat
        [ get_ids

        , dec our_array_offset $ v32 get_global_id_0 `quot` (height*width) * (height*width)

        , dec x_index $ (v32 get_global_id_0 `rem` (height*width)) `quot` height
        , dec y_index $ v32 get_global_id_0 `rem` height

        , dec odata_offset $
            (basic_odata_offset `quot` primByteSize t) + v32 our_array_offset
        , dec idata_offset $
            (basic_idata_offset `quot` primByteSize t) + v32 our_array_offset

        , dec index_in $ v32 y_index * width + v32 x_index
        , dec index_out $ v32 x_index * height + v32 y_index

          -- Direct global-to-global copy; no staging through 'block'.
        , If (v32 get_global_id_0 .<. input_size)
            (Write odata (elements $ v32 odata_offset + v32 index_out) t (Space "global") Nonvolatile $
               index idata (elements $ v32 idata_offset + v32 index_in) t (Space "global") Nonvolatile)
            mempty
        ]

    TransposeLowWidth ->
      mkTranspose $ lowDimBody
        (v32 get_group_id_0 * block_dim + (v32 get_local_id_0 `quot` muly))
        (v32 get_group_id_1 * block_dim * muly + v32 get_local_id_1 +
         (v32 get_local_id_0 `rem` muly) * block_dim)
        (v32 get_group_id_1 * block_dim * muly + v32 get_local_id_0 +
         (v32 get_local_id_1 `rem` muly) * block_dim)
        (v32 get_group_id_0 * block_dim + (v32 get_local_id_1 `quot` muly))

    TransposeLowHeight ->
      mkTranspose $ lowDimBody
        (v32 get_group_id_0 * block_dim * mulx + v32 get_local_id_0 +
         (v32 get_local_id_1 `rem` mulx) * block_dim)
        (v32 get_group_id_1 * block_dim + (v32 get_local_id_1 `quot` mulx))
        (v32 get_group_id_1 * block_dim + (v32 get_local_id_0 `quot` mulx))
        (v32 get_group_id_0 * block_dim * mulx + v32 get_local_id_1 +
         (v32 get_local_id_0 `rem` mulx) * block_dim)

    TransposeNormal ->
      mkTranspose $ mconcat
        [ dec x_index $ v32 get_global_id_0
        , dec y_index $ v32 get_group_id_1 * tile_dim + v32 get_local_id_1
          -- Load phase: input (global) -> block (local).
        , when (v32 x_index .<. width) $
            For j Int32 elemsPerThread $
              let i = v32 j * (tile_dim `quot` elemsPerThread)
              in mconcat
                   [ dec index_in $ (v32 y_index + i) * width + v32 x_index
                   , when (v32 y_index + i .<. height .&&.
                           v32 index_in .<. input_size) $
                       Write block (elements $ (v32 get_local_id_1 + i) * (tile_dim+1)
                                               + v32 get_local_id_0)
                       t (Space "local") Nonvolatile $
                       index idata (elements $ v32 idata_offset + v32 index_in)
                       t (Space "global") Nonvolatile
                   ]
        , Op $ Barrier FenceLocal
          -- Store phase: the roles of the two group ids are swapped
          -- relative to the load phase.
        , SetScalar x_index $ v32 get_group_id_1 * tile_dim + v32 get_local_id_0
        , SetScalar y_index $ v32 get_group_id_0 * tile_dim + v32 get_local_id_1
        , when (v32 x_index .<. height) $
            For j Int32 elemsPerThread $
              let i = v32 j * (tile_dim `quot` elemsPerThread)
              in mconcat
                   [ dec index_out $ (v32 y_index + i) * height + v32 x_index
                   , when (v32 y_index + i .<. width .&&.
                           v32 index_out .<. output_size) $
                       Write odata (elements $ v32 odata_offset + v32 index_out)
                       t (Space "global") Nonvolatile $
                       index block (elements $ v32 get_local_id_0 * (tile_dim+1)
                                               + v32 get_local_id_1+i)
                       t (Space "local") Nonvolatile
                   ]
        ]

  where
    -- Declare a nonvolatile int32 scalar and immediately assign it.
    dec :: VName -> Exp -> Code a
    dec v e = DeclareScalar v Nonvolatile int32 <> SetScalar v e

    -- Read a variable as an int32 expression.
    v32 :: VName -> Exp
    v32 = flip var int32

    -- Side length of the tile used by 'TransposeNormal'.
    tile_dim :: Exp
    tile_dim = 2 * block_dim

    -- One-armed conditional.
    when :: Exp -> Code a -> Code a
    when a b = If a b mempty

    (odata, basic_odata_offset, idata, basic_idata_offset,
     width, height, input_size, output_size,
     mulx, muly, _num_arrays, block) = args

    -- Be extremely careful when editing this list to ensure that
    -- the names match up.  Also, be careful that the tags on
    -- these names do not conflict with the tags of the
    -- surrounding code.  We accomplish the latter by using very
    -- low tags (normal variables start at least in the low
    -- hundreds).
    [   our_array_offset , x_index , y_index
      , odata_offset, idata_offset, index_in, index_out
      , get_global_id_0
      , get_local_id_0, get_local_id_1
      , get_group_id_0, get_group_id_1, get_group_id_2
      , j] =
      zipWith (flip VName) [30..] $ map nameFromString
      [ "our_array_offset" , "x_index" , "y_index"
      , "odata_offset", "idata_offset", "index_in", "index_out"
      , "get_global_id_0"
      , "get_local_id_0", "get_local_id_1"
      , "get_group_id_0", "get_group_id_1", "get_group_id_2"
      , "j"]

    -- Declare and fetch every thread/group id used by any variant.
    get_ids :: KernelCode
    get_ids =
      mconcat
        [ DeclareScalar get_global_id_0 Nonvolatile int32
        , Op $ GetGlobalId get_global_id_0 0
        , DeclareScalar get_local_id_0 Nonvolatile int32
        , Op $ GetLocalId get_local_id_0 0
        , DeclareScalar get_local_id_1 Nonvolatile int32
        , Op $ GetLocalId get_local_id_1 1
        , DeclareScalar get_group_id_0 Nonvolatile int32
        , Op $ GetGroupId get_group_id_0 0
        , DeclareScalar get_group_id_1 Nonvolatile int32
        , Op $ GetGroupId get_group_id_1 1
        , DeclareScalar get_group_id_2 Nonvolatile int32
        , Op $ GetGroupId get_group_id_2 2
        ]

    -- Shared prologue for the tiled variants: fetch the ids, compute
    -- the per-array offset from group id 2, convert the supplied
    -- offsets by dividing by the element byte size, then run 'body'.
    mkTranspose :: KernelCode -> KernelCode
    mkTranspose body =
      mconcat
        [ get_ids
        , dec our_array_offset $ v32 get_group_id_2 * width * height
        , dec odata_offset $
            (basic_odata_offset `quot` primByteSize t) + v32 our_array_offset
        , dec idata_offset $
            (basic_idata_offset `quot` primByteSize t) + v32 our_array_offset
        , body
        ]

    -- Body shared by the low-width and low-height variants.  The four
    -- arguments are the x/y index expressions for the load phase and
    -- the x/y index expressions for the store phase.
    lowDimBody :: Exp -> Exp -> Exp -> Exp -> KernelCode
    lowDimBody x_in_index y_in_index x_out_index y_out_index =
      mconcat
        [ dec x_index x_in_index
        , dec y_index y_in_index
        , dec index_in $ v32 y_index * width + v32 x_index
        , when (v32 x_index .<. width .&&. v32 y_index .<. height .&&. v32 index_in .<. input_size) $
            Write block (elements $ v32 get_local_id_1 * (block_dim+1) + v32 get_local_id_0)
            t (Space "local") Nonvolatile $
            index idata (elements $ v32 idata_offset + v32 index_in)
            t (Space "global") Nonvolatile
        , Op $ Barrier FenceLocal
        , SetScalar x_index x_out_index
        , SetScalar y_index y_out_index
        , dec index_out $ v32 y_index * height + v32 x_index
        , when (v32 x_index .<. height .&&. v32 y_index .<. width .&&. v32 index_out .<. output_size) $
            Write odata (elements $ v32 odata_offset + v32 index_out)
            t (Space "global") Nonvolatile $
            index block (elements $ v32 get_local_id_0 * (block_dim+1) + v32 get_local_id_1)
            t (Space "local") Nonvolatile
        ]

-- | Generate a transpose kernel.  There is special support to handle
-- input arrays with low width, low height, or both.
--
-- Normally when transposing a @[2][n]@ array we would use a @FUT_BLOCK_DIM x
-- FUT_BLOCK_DIM@ group to process a @[2][FUT_BLOCK_DIM]@ slice of the input
-- array. This would mean that many of the threads in a group would be inactive.
-- We try to remedy this by using a special kernel that will process a larger
-- part of the input, by using more complex indexing. In our example, we could
-- use all threads in a group if we are processing @(2/FUT_BLOCK_DIM)@ as large
-- a slice of each row per group. The variable @mulx@ contains this factor for
-- the kernel to handle input arrays with low height.
--
-- See issue #308 on GitHub for more details.
--
-- These kernels are optimized to ensure all global reads and writes
-- are coalesced, and to avoid bank conflicts in shared memory.  Each
-- thread group transposes a 2D tile of block_dim*2 by block_dim*2
-- elements. The size of a thread group is block_dim/2 by
-- block_dim*2, meaning that each thread will process 4 elements in a
-- 2D tile.  The shared memory array containing the 2D tile consists
-- of block_dim*2 by block_dim*2+1 elements. Padding each row with
-- an additional element prevents bank conflicts from occurring when
-- the tile is accessed column-wise.
--
-- Note that input_size and output_size may not equal width*height if
-- we are dealing with a truncated array - this happens sometimes for
-- coalescing optimisations.
mapTransposeKernel :: String -> Integer -> TransposeArgs -> PrimType -> TransposeType
                   -> Kernel
mapTransposeKernel desc block_dim_int args t kind =
  Kernel
  { kernelBody = declare_tile <> alloc_tile <> mapTranspose block_dim args t kind
  , kernelUses = uses
  , kernelNumGroups = num_groups
  , kernelGroupSize = group_size
  , kernelName = nameFromString name
  , kernelFailureTolerant = True
  }
  where
    -- Take the argument tuple apart once; every component except the
    -- memory names ends up in the kernel's scalar uses below.
    ( odata, basic_odata_offset, idata, basic_idata_offset
      , width, height, input_size, output_size
      , mulx, muly, num_arrays
      , block ) = args

    -- The block dimension as an expression, for use in size arithmetic.
    block_dim = fromInteger block_dim_int

    -- The kernel body starts by declaring and allocating the local
    -- memory block before running the transposition code proper.
    declare_tile = DeclareMem block (Space "local")
    alloc_tile = Op (LocalAlloc block block_size)

    -- Number of bytes in a square tile with the given side length,
    -- where every row is padded with one extra element.
    paddedTileBytes side = fromInteger (side * (side + 1) * primByteSize t)

    -- Size in bytes of the local memory allocation.
    block_size =
      case kind of
        -- Not used, but AMD's OpenCL does not like zero-size local
        -- memory.
        TransposeSmall     -> 1
        TransposeNormal    -> paddedTileBytes (2 * block_dim_int)
        TransposeLowWidth  -> paddedTileBytes block_dim_int
        TransposeLowHeight -> paddedTileBytes block_dim_int

    (num_groups, group_size) = launchDims kind

    -- Grid dimensions and group dimensions for each kind of kernel.
    launchDims :: TransposeType -> ([Exp], [Exp])
    launchDims TransposeSmall =
      ( [(num_arrays * width * height) `divUp` threads_per_group]
      , [threads_per_group] )
      where threads_per_group = block_dim * block_dim
    launchDims TransposeLowWidth =
      lowDimKernelAndGroupSize block_dim num_arrays width (height `divUp` muly)
    launchDims TransposeLowHeight =
      lowDimKernelAndGroupSize block_dim num_arrays (width `divUp` mulx) height
    launchDims TransposeNormal =
      ( [width `divUp` tile_dim, height `divUp` tile_dim, num_arrays]
      , [tile_dim, tile_dim `quot` elemsPerThread, 1] )
      where tile_dim = block_dim * 2

    -- Every free variable of the scalar arguments is passed as an
    -- int32 scalar; the two arrays are passed as memory.
    uses = scalar_uses ++ memory_uses
    scalar_uses =
      [ ScalarUse v int32
      | v <- namesToList (mconcat (map freeIn scalar_args)) ]
    scalar_args =
      [ basic_odata_offset, basic_idata_offset, num_arrays
      , width, height, input_size, output_size, mulx, muly ]
    memory_uses = map MemoryUse [odata, idata]

    -- The kernel name is the description plus a suffix that
    -- identifies the kernel variant.
    name = desc ++ suffix
    suffix =
      case kind of
        TransposeSmall     -> "_small"
        TransposeLowHeight -> "_low_height"
        TransposeLowWidth  -> "_low_width"
        TransposeNormal    -> ""

-- | Grid and group dimensions for a kernel whose groups are
-- @block_dim@ by @block_dim@ by 1.  The first two grid dimensions are
-- the element counts in x and y, each divided (rounding up) by
-- @block_dim@; the third grid dimension is the number of arrays.
lowDimKernelAndGroupSize :: Exp -> Exp -> Exp -> Exp -> ([Exp], [Exp])
lowDimKernelAndGroupSize block_dim num_arrays x_elems y_elems =
  (grid_dims, group_dims)
  where
    grid_dims = map (`divUp` block_dim) [x_elems, y_elems] ++ [num_arrays]
    group_dims = [block_dim, block_dim, 1]