{-# LANGUAGE TypeFamilies #-}

-- | Perform a restricted form of block+register tiling corresponding to
--   the following pattern:
--     * a redomap is quasi-perfectly nested inside a kernel with at
--       least two parallel dimensions (the perfectly nested restriction
--       is relaxed a bit to allow for SGEMM);
--     * all streamed arrays of redomap are one dimensional;
--     * all streamed arrays are variant to exactly one of the two
--       innermost parallel dimensions, and conversely for each of
--       the two innermost parallel dimensions, there is at least
--       one streamed array variant to it;
--     * the stream's result is a tuple of scalar values, which are
--       also the "thread-in-space" return of the kernel.
--     * We have further restrictions that in principle can be relaxed:
--          the redomap has exactly two array inputs
--          the redomap produces one scalar result
--          the kernel produces one scalar result
module Futhark.Optimise.BlkRegTiling (mmBlkRegTiling, doRegTiling3D) where

import Control.Monad.Reader
import Data.List qualified as L
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Sequence qualified as Seq
import Futhark.IR.GPU
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.MonadFreshNames
import Futhark.Optimise.TileLoops.Shared
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute

-- | The 64-bit integer constant 0, as a 'SubExp'.
se0 :: SubExp
se0 = intConst Int64 0

-- | The 64-bit integer constant 1, as a 'SubExp'.
se1 :: SubExp
se1 = intConst Int64 1

-- | The 64-bit integer constant 2, as a 'SubExp'.
se2 :: SubExp
se2 = intConst Int64 2

-- | The 64-bit integer constant 4, as a 'SubExp'.
se4 :: SubExp
se4 = intConst Int64 4

-- | The 64-bit integer constant 8, as a 'SubExp'.
se8 :: SubExp
se8 = intConst Int64 8

-- | Bind a fresh variable (base name @se_name@) to a 'Scratch'
-- expression with element type @t@ and dimensions @shape@, and return
-- the name of that variable.
scratch :: MonadBuilder m => String -> PrimType -> [SubExp] -> m VName
scratch se_name t shape = letExp se_name $ BasicOp $ Scratch t shape

-- | Main helper function for Register-and-Block Tiling.
--
-- Emits the code for one iteration of the tiled loop over the common
-- (sequentialised) dimension:
--
--   1. binds @kk = kk0 * tk@;
--   2. scatters a tile of each of the two redomap inputs into the
--      buffers @a_loc_init'@ and @b_loc_init'@ (see @copyGlb2ShMem@);
--   3. runs a loop of @tk@ iterations which copies the needed elements
--      into per-thread arrays and folds them into the accumulator
--      @thd_res_merge@ using @map_lam@ followed by @red_lam@.
--
-- The result is the list @[thd_acc, a_loc, b_loc]@: the updated
-- accumulator and the two updated tile buffers.
--
-- When the final 'Bool' (@epilogue@) is 'True', the generated code
-- additionally guards every access along the common dimension with a
-- check against @common_dim@; when it is 'False' those checks are
-- replaced by the constant @true@.
kkLoopBody ::
  Env ->
  ( (SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp),
    SegLevel,
    [Int],
    (VName, SubExp, VName, SubExp, SubExp),
    (SubExp, SubExp),
    (VName, VName),
    (Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
    (Lambda GPU, Lambda GPU)
  ) ->
  VName ->
  (VName, VName, VName) ->
  Bool ->
  Builder GPU [VName]
kkLoopBody
  env
  ( (rx, ry, tx, ty, tk, tk_div_tx, _tk_div_ty, tx_rx),
    segthd_lvl,
    var_dims,
    (gtid_x, width_B, gtid_y, height_A, common_dim),
    (a_loc_sz, b_loc_sz),
    (iii, jjj),
    (load_A, inp_A, pt_A, load_B, inp_B, pt_B),
    (map_lam, red_lam)
    )
  kk0
  (thd_res_merge, a_loc_init', b_loc_init')
  epilogue = do
    let (map_t1, map_t2) = (pt_A, pt_B)
    kk <- letExp "kk" =<< toExp (le64 kk0 * pe64 tk)
    -- copy A to local memory
    (a_loc, aCopyLoc2Reg) <-
      copyGlb2ShMem kk (gtid_y, iii, map_t1, height_A, inp_A, load_A, a_loc_sz, a_loc_init')

    -- copy B from global to shared memory
    (b_loc, bCopyLoc2Reg) <-
      copyGlb2ShMem kk (gtid_x, jjj, map_t2, width_B, inp_B, load_B, b_loc_sz, b_loc_init')

    -- inner loop updating this thread's accumulator (loop k in mmm_kernels).
    thd_acc <- forLoop tk [thd_res_merge] $ \k [acc_merge] ->
      resultBodyM
        =<< letTupExp' "foo"
        =<< eIf
          ( toExp $
              if epilogue
                then le64 kk + le64 k .<. pe64 common_dim
                else true -- if in prologue, always compute redomap.
          )
          ( do
              reg_mem <- segMap2D "reg_mem" segthd_lvl ResultPrivate (ty, tx) $
                \(ltid_y, ltid_x) -> do
                  -- copy A from local memory to registers
                  asss <- aCopyLoc2Reg k ltid_y
                  -- copy B from local memory to registers
                  bsss <- bCopyLoc2Reg k ltid_x
                  pure $ varsRes [asss, bsss]
              -- the segMap2D body above returns exactly two results.
              let [asss, bsss] = reg_mem
              mkRedomapOneTileBody acc_merge asss bsss True
          )
          (resultBodyM [Var acc_merge])
    pure [thd_acc, a_loc, b_loc]
    where
      -- Given the thread indices and the (i0, k0) chunk indices of the
      -- scatter, bind and return the tile-relative indices @i@ and @k@
      -- together with the flat index expression into the padded tile
      -- buffer.  The Bool selects the coalesced (True) or the transposed
      -- (False) layout.
      mk_ik ::
        Bool ->
        (VName, VName) ->
        (VName, VName) ->
        Builder GPU (VName, VName, TPrimExp Int64 VName)
      mk_ik is_coal (thd_y, thd_x) (i0, k0)
        | is_coal = do
            -- not-transposed case (i.e., already coalesced)
            let (t_par, t_seq) = (tx, tk)
            k <- letExp "k" =<< toExp (le64 thd_x + le64 k0 * pe64 t_par)
            i <- letExp "i" =<< toExp (le64 thd_y + le64 i0 * pe64 t_par)
            -- we have padded to minimize bank conflicts,
            -- hence the length of inner dim is (t_seq + 1)
            let e = le64 k + le64 i * (pe64 t_seq + pe64 se1)
            pure (i, k, e)
      mk_ik _ (thd_y, thd_x) (i0, k0) = do
        -- matrix is transposed case (i.e., uncoalesced):
        let (t_par, tr_par) = (tx, tx_rx)
        k <- letExp "k" =<< toExp (le64 thd_y + le64 k0 * pe64 t_par)
        i <- letExp "i" =<< toExp (le64 thd_x + le64 i0 * pe64 t_par)
        -- we have padded to minimize bank conflicts,
        -- hence the length of inner dim is (tr_par + 1)
        let e = le64 i + le64 k * (pe64 tr_par + pe64 se1)
        pure (i, k, e)

      -- Decide, from the statement that slices the streamed array
      -- @slc_X@ out of some array @x@, whether the coalesced layout
      -- should be used: True if @x@ has no entry in the index-function
      -- environment, or if its (single-LMAD) index function has its
      -- last dimension unpermuted with stride 1.  Any other shape of
      -- statement or index function is reported via 'error'.
      isInnerCoal :: Env -> VName -> Stm GPU -> Bool
      isInnerCoal (_, ixfn_env) slc_X (Let pat _ (BasicOp (Index x _)))
        | [slc_X'] <- patNames pat,
          slc_X == slc_X',
          Nothing <- M.lookup x ixfn_env =
            True -- if not in the table, we assume not-transposed!
      isInnerCoal (_, ixfn_env) slc_X (Let pat _ (BasicOp (Index x _)))
        | [slc_X'] <- patNames pat,
          slc_X == slc_X',
          Just ixf_fn <- M.lookup x ixfn_env,
          (IxFun.IxFun (lmad :| []) _ _) <- ixf_fn =
            let lmad_dims = IxFun.lmadDims lmad
                q = length lmad_dims
                -- NOTE(review): 'last' is partial; this presumes the LMAD
                -- has at least one dimension -- confirm against callers.
                last_perm = IxFun.ldPerm $ last lmad_dims
                stride = IxFun.ldStride $ last lmad_dims
                res = last_perm == q - 1 && (stride == pe64 (intConst Int64 1))
             in res
      isInnerCoal _ _ _ =
        error "Futhark.Optimise.BlkRegTiling.isInnerCoal: not an error, but I would like to know why!"

      -- Build the body that, for every thread of the (ty, tx) space,
      -- updates that thread's ry-by-rx slice of @acc_merge@ with the
      -- register arrays @asss@ and @bsss@.  When @fits_ij@ is False an
      -- extra bounds check against height_A / width_B is emitted.
      mkRedomapOneTileBody ::
        VName ->
        VName ->
        VName ->
        Bool ->
        Builder GPU (Body GPU)
      mkRedomapOneTileBody acc_merge asss bsss fits_ij = do
        -- the actual redomap.
        redomap_res <- segMap2D "redomap_res" segthd_lvl ResultPrivate (ty, tx) $
          \(ltid_y, ltid_x) -> do
            as <- index "as" asss [ltid_y, ltid_x]
            bs <- index "bs" bsss [ltid_y, ltid_x]
            css_init <- index "css_init" acc_merge [ltid_y, ltid_x]

            css <- forLoop ry [css_init] $ \i [css_merge] -> do
              css <- forLoop rx [css_merge] $ \j [css_merge'] ->
                resultBodyM
                  =<< letTupExp' "foo"
                  =<< eIf
                    ( toExp $
                        if fits_ij
                          then true
                          else -- this condition is never needed because
                          -- if i and j are out of range then css[i,j]
                          -- is garbage anyways and should not be written.
                          -- so fits_ij should be always true!!!

                            le64 iii + le64 i + pe64 ry * le64 ltid_y
                              .<. pe64 height_A
                              .&&. le64 jjj + le64 j + pe64 rx * le64 ltid_x
                                .<. pe64 width_B
                    )
                    ( do
                        a <- index "a" as [i]
                        b <- index "b" bs [j]
                        c <- index "c" css_merge' [i, j]

                        -- fresh copies, since the lambdas are inlined
                        -- once per generated loop body.
                        map_lam' <- renameLambda map_lam
                        red_lam' <- renameLambda red_lam

                        -- the inputs to map are supposed to be permuted with the
                        -- inverted permutation, so as to reach the original position;
                        -- it just so happens that the inverse of [a,b] is [b,a]
                        let map_inp_reg = if var_dims == [0, 1] then [a, b] else [b, a]

                        map_res <- eLambda map_lam' (map (eSubExp . Var) map_inp_reg)
                        ~[red_res] <- eLambda red_lam' (map eSubExp $ Var c : map resSubExp map_res)
                        css <- update "css" css_merge' [i, j] (resSubExp red_res)

                        resultBodyM [Var css]
                    )
                    (resultBodyM [Var css_merge'])
              resultBodyM [Var css]
            pure [varRes css]
        resultBodyM $ map Var redomap_res

      -- Scatter one tile of the streamed array @inp_X@ into the buffer
      -- @x_loc_init'@ and return the updated buffer together with a
      -- function that, given the loop index @k@ and a thread index,
      -- copies that thread's @rx@ elements out of the buffer.
      copyGlb2ShMem ::
        VName ->
        (VName, VName, PrimType, SubExp, VName, Stm GPU, SubExp, VName) ->
        Builder GPU (VName, VName -> VName -> Builder GPU VName)
      copyGlb2ShMem kk (gtid, ii, ptp_X_el, parlen_X, inp_X, load_X, loc_sz_X, x_loc_init') = do
        let (t_par, r_par, tseq_div_tpar) = (tx, rx, tk_div_tx)
            is_inner_coal = isInnerCoal env inp_X load_X
            str_A = baseString inp_X
        x_loc <-
          segScatter2D (str_A ++ "_glb2loc") loc_sz_X x_loc_init' [r_par, tseq_div_tpar] (t_par, t_par) $
            scatterFun is_inner_coal

        pure (x_loc, copyLoc2Reg is_inner_coal str_A x_loc)
        where
          -- Fill a fresh scratch array of @rx@ elements from @x_loc@,
          -- using the flat index that matches the layout chosen by
          -- @is_inner_coal@ (cf. @mk_ik@).
          copyLoc2Reg ::
            Bool ->
            String ->
            VName ->
            VName ->
            VName ->
            Builder GPU VName
          copyLoc2Reg is_inner_coal str_A x_loc k ltid_yx = do
            let (r_par, t_seq, tr_par) = (rx, tk, tx_rx)
            xsss_init <- scratch (str_A ++ "_init_regs") ptp_X_el [r_par]
            forLoop r_par [xsss_init] $ \ij [xsss_merge] -> do
              x_loc_ind <-
                letExp (str_A ++ "_loc_ind")
                  =<< toExp
                    ( if is_inner_coal
                        then le64 k + (le64 ltid_yx * pe64 r_par + le64 ij) * (pe64 t_seq + pe64 se1)
                        else le64 ij + le64 ltid_yx * pe64 r_par + le64 k * (pe64 tr_par + pe64 se1)
                    )
              xsss <-
                update (str_A ++ "_regs") xsss_merge [ij] . Var
                  =<< index (str_A ++ "_loc_elem") x_loc [x_loc_ind]
              resultBodyM [Var xsss]

          -- Per-element function handed to segScatter2D: returns the
          -- value to write and the flat index to write it at.  The
          -- value is a blank element when the source position is out
          -- of bounds, and the index is -1 when @k@ is not below @tk@.
          scatterFun ::
            Bool ->
            [VName] ->
            (VName, VName) ->
            Builder GPU (SubExp, SubExp)
          scatterFun is_inner_coal [i0, k0] (thd_y, thd_x) = do
            let str_A = baseString inp_X
                t_seq = tk
            (i, k, epx_loc_fi) <- mk_ik is_inner_coal (thd_y, thd_x) (i0, k0)
            -- (re)bind the global thread id that @load_X@ refers to.
            letBindNames [gtid] =<< toExp (le64 ii + le64 i)
            a_seqdim_idx <- letExp (str_A ++ "_seqdim_idx") =<< toExp (le64 kk + le64 k)

            a_elem <-
              letSubExp (str_A ++ "_elem")
                =<< eIf
                  ( toExp $
                      le64 gtid .<. pe64 parlen_X
                        .&&. if epilogue
                          then le64 a_seqdim_idx .<. pe64 common_dim
                          else true
                  )
                  ( do
                      addStm load_X
                      res <- index "A_elem" inp_X [a_seqdim_idx]
                      resultBodyM [Var res]
                  )
                  (eBody [eBlank $ Prim ptp_X_el])

            a_loc_ind <-
              letSubExp (str_A ++ "_loc_ind")
                =<< eIf
                  (toExp $ le64 k .<. pe64 t_seq)
                  (eBody [toExp epx_loc_fi])
                  (eBody [eSubExp $ intConst Int64 (-1)])
            pure (a_elem, a_loc_ind)
          scatterFun _ _ _ =
            error "Futhark.Optimise.BlkRegTiling.scatterFun: 2nd arg should be an array with 2 elements!"

-- ToDo: we need tx == ty (named t_par), and rx == ry (named r_par)
--       in order to handle all the cases without transpositions.
--       additionally, of course, we need that tk is a multiple of t_par.
-- | Entry point for block+register tiling of a statement: first try
-- 'mmBlkRegTilingAcc'; if that yields 'Nothing', fall back to
-- 'mmBlkRegTilingNrm'.
mmBlkRegTiling :: Env -> Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
mmBlkRegTiling env stm = do
  res <- mmBlkRegTilingAcc env stm
  case res of
    Nothing -> mmBlkRegTilingNrm env stm
    _ -> pure res

-- | Block+register tiling for kernels whose single result is an
-- accumulator.
--
-- The statement is rewritten only if all of the following hold:
--
--   * it is a 'SegMap' at 'SegThread' level whose body returns exactly
--     one variable, via 'Returns' with 'ResultMaySimplify' and no
--     certificates;
--   * the single result type satisfies 'isAcc';
--   * the segment space has at least two dimensions (the two innermost
--     ones are bound to @gtid_y@/@height_A@ and @gtid_x@/@width_B@);
--   * 'matchesBlkRegTile' recognises the body;
--   * the code following the redomap (@code2'@) contains an 'UpdateAcc'
--     that binds the kernel result and whose single accumulated value
--     is the redomap result (see @checkAccumulatesRedomapRes@).
--
-- On success the result is @Just (host_stms, new_kernel)@, where
-- @new_kernel@ is a 'SegGroup'-level 'SegMap' over
-- @rem_outer_dims ++ [gid_t, gid_y, gid_x]@.  The extra dimension
-- @gid_t@ splits the common (reduced) dimension into chunks of
-- @tk * rk@ elements.  In every other case 'Nothing' is returned.
mmBlkRegTilingAcc :: Env -> Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
mmBlkRegTilingAcc env (Let pat aux (Op (SegOp (SegMap SegThread {} seg_space ts old_kbody))))
  | KernelBody () kstms [Returns ResultMaySimplify cs (Var res_nm)] <- old_kbody,
    cs == mempty,
    -- check kernel has one result of accumulator type
    [res_tp] <- ts,
    isAcc res_tp,
    -- we get the global-thread id for the two inner dimensions,
    --   as we are probably going to use it in code generation
    (gtid_x, width_B) : (gtid_y, height_A) : rem_outer_dims_rev <-
      reverse $ unSegSpace seg_space,
    rem_outer_dims <- reverse rem_outer_dims_rev,
    Just
      ( code2',
        (load_A, inp_A, map_t1, load_B, inp_B, map_t2),
        common_dim,
        var_dims,
        (map_lam, red_lam, red_ne, redomap_orig_res, red_t)
        ) <-
      matchesBlkRegTile seg_space kstms,
    checkAccumulatesRedomapRes res_nm code2' redomap_orig_res = do
      -- Here we start the implementation --
      ---- in this binder: host code and outer seggroup (ie. the new kernel) ----
      (new_kernel, host_stms) <- runBuilder $ do
        -- host code
        (rx, ry, tx, ty, tk, tk_div_tx, tk_div_ty, tx_rx, ty_ry, a_loc_sz, b_loc_sz) <-
          mkTileMemSizes height_A width_B common_dim

        -- rk is the maximal number of tk-sized tiles processed by one
        -- group along the common dimension.
        rk <- letSubExp "rk" $ BasicOp $ SubExp $ intConst Int64 8 -- 16 and 8 seem good values
        tk_rk <- letSubExp "tk_rk" =<< toExp (pe64 tk * pe64 rk)

        gridDim_t <- letSubExp "gridDim_t" =<< ceilDiv common_dim tk_rk
        gridDim_y <- letSubExp "gridDim_y" =<< ceilDiv height_A ty_ry
        gridDim_x <- letSubExp "gridDim_x" =<< ceilDiv width_B tx_rx

        -- total number of groups: the three tiled dimensions multiplied
        -- by the sizes of all remaining outer dimensions.
        let gridxyt_pexp = pe64 gridDim_y * pe64 gridDim_x * pe64 gridDim_t
            grid_pexp =
              foldl (\x d -> pe64 d * x) gridxyt_pexp $
                map snd rem_outer_dims_rev

        (grid_size, group_size, segthd_lvl) <- mkNewSegthdLvl tx ty grid_pexp
        (gid_x, gid_y, gid_flat) <- mkGidsXYF
        gid_t <- newVName "gid_t"

        ---- in this binder: outer seggroup ----
        (ret_seggroup, stms_seggroup) <- runBuilder $ do
          -- offsets of this group along the y, x and common dimensions
          iii <- letExp "iii" =<< toExp (le64 gid_y * pe64 ty_ry)
          jjj <- letExp "jjj" =<< toExp (le64 gid_x * pe64 tx_rx)
          ttt <- letExp "ttt" =<< toExp (le64 gid_t * pe64 tk_rk)

          -- initialize register mem with neutral elements and create shmem
          (cssss, a_loc_init, b_loc_init) <-
            initRegShmem
              (rx, tx, ry, ty, a_loc_sz, b_loc_sz)
              (map_t1, map_t2, red_t)
              segthd_lvl
              red_ne

          -- build prologue.
          -- full_tiles = min rk ((common_dim - ttt) / tk)
          elems_on_t <- letSubExp "elems_on_t" =<< toExp (pe64 common_dim - le64 ttt)
          tiles_on_t <- letSubExp "tiles_on_t" $ BasicOp $ BinOp (SQuot Int64 Unsafe) elems_on_t tk
          full_tiles <- letExp "full_tiles" $ BasicOp $ BinOp (SMin Int64) rk tiles_on_t

          let ct_arg =
                ( (rx, ry, tx, ty, tk, tk_div_tx, tk_div_ty, tx_rx),
                  segthd_lvl,
                  var_dims,
                  (gtid_x, width_B, gtid_y, height_A, common_dim),
                  (a_loc_sz, b_loc_sz),
                  (iii, jjj),
                  (load_A, inp_A, map_t1, load_B, inp_B, map_t2),
                  (map_lam, red_lam)
                )

          prologue_res_list <-
            forLoop' (Var full_tiles) [cssss, a_loc_init, b_loc_init] $
              \kk0 [thd_res_merge, a_loc_merge, b_loc_merge] -> do
                -- index of the current tile along the common dimension
                off_t <- letExp "off_t" =<< toExp (pe64 rk * le64 gid_t + le64 kk0)
                process_full_tiles <-
                  kkLoopBody env ct_arg off_t (thd_res_merge, a_loc_merge, b_loc_merge) False

                resultBodyM $ map Var process_full_tiles

          let prologue_res : a_loc_reuse : b_loc_reuse : _ = prologue_res_list

          -- If this group processed all rk tiles, or its tiles end
          -- exactly at common_dim, the prologue result is final;
          -- otherwise one more tile is processed by 'kkLoopBody' with
          -- its last argument set to True.
          redomap_res_lst <-
            letTupExp "redomap_res_if"
              =<< eIf
                ( toExp $
                    le64 full_tiles .==. pe64 rk
                      .||. pe64 common_dim .==. (pe64 tk * le64 full_tiles + le64 ttt)
                )
                (resultBodyM $ map Var prologue_res_list)
                ( do
                    off_t <- letExp "off_t" =<< toExp (pe64 rk * le64 gid_t + le64 full_tiles)
                    process_sprs_tile <-
                      kkLoopBody env ct_arg off_t (prologue_res, a_loc_reuse, b_loc_reuse) True

                    resultBodyM $ map Var process_sprs_tile
                )
          let redomap_res : _ = redomap_res_lst

          -- support for non-empty code2'
          --  segmap (ltid_y < ty, ltid_x < tx) {
          --    for i < ry do
          --      for j < rx do
          --        res = if (iii+ltid_y*ry+i < height_A && jjj+ltid_x*rx+j < width_B)
          --              then code2' else dummy
          --        final_res[i,j] = res
          mkEpilogueAccRes
            segthd_lvl
            (redomap_orig_res, redomap_res)
            (res_nm, res_tp)
            (ty, tx, ry, rx)
            (iii, jjj)
            (gtid_y, gtid_x)
            (height_A, width_B, rem_outer_dims)
            code2'

        let grid = KernelGrid (Count grid_size) (Count group_size)
            level' = SegGroup SegNoVirt (Just grid)
            space' = SegSpace gid_flat (rem_outer_dims ++ [(gid_t, gridDim_t), (gid_y, gridDim_y), (gid_x, gridDim_x)])
            kbody' = KernelBody () stms_seggroup ret_seggroup
        pure $ Let pat aux $ Op $ SegOp $ SegMap level' space' ts kbody'
      pure $ Just (host_stms, new_kernel)
  where
    -- True iff the given type is an 'Acc' whose first field equals the
    -- given name.
    sameAccType acc_sglton (Acc sglton _ _ _) =
      acc_sglton == sglton
    sameAccType _ _ = False
    --
    -- Finds the unique free variable of the original kernel body whose
    -- type is an 'Acc' with the same first field as the given result
    -- type.  Calls 'error' if the given type is not an 'Acc' with
    -- exactly one element type, or if there is not exactly one such
    -- free variable.
    getAccumFV (Acc singleton _shp [_eltp] _) = do
      let fvs = namesToList $ freeIn old_kbody -- code
      tps <- localScope (scopeOfSegSpace seg_space) $ do
        mapM lookupType fvs
      let (acc_0s, _) = unzip $ filter (sameAccType singleton . snd) $ zip fvs tps
      case acc_0s of
        [acc_0] -> pure acc_0
        _ -> error "Impossible case reached when treating accumulators!"
    getAccumFV tp = error ("Should be an accumulator type at this point, given: " ++ prettyString tp)
    --
    -- checks that the redomap result is used directly as the accumulated value,
    -- in which case it is safe to parallelize the innermost dimension (of tile tk)
    checkAccumulatesRedomapRes res_nm acc_code redomap_orig_res =
      any accumulatesRedomapRes $ stmsToList acc_code
      where
        -- An 'UpdateAcc' binding exactly the kernel result and
        -- accumulating exactly one value, namely the redomap result.
        accumulatesRedomapRes (Let (Pat [pat_el]) _aux (BasicOp (UpdateAcc _acc_nm _ind [v]))) =
          patElemName pat_el == res_nm && v == Var redomap_orig_res
        accumulatesRedomapRes _ = False
    --
    -- epilogue for accumulator result type
    mkEpilogueAccRes
      segthd_lvl
      (redomap_orig_res, redomap_res)
      (res_nm, res_tp)
      (ty, tx, ry, rx)
      (iii, jjj)
      (gtid_y, gtid_x)
      (height_A, width_B, _rem_outer_dims)
      code2' = do
        rss_init <- getAccumFV res_tp
        rssss_list <- segMap2D "rssss" segthd_lvl ResultMaySimplify (ty, tx) $ \(ltid_y, ltid_x) -> do
          (css, ii, jj) <- getThdRedomapRes (rx, ry) (ltid_x, ltid_y) (iii, jjj, redomap_res)
          -- The accumulator is threaded through both register-tile
          -- loops as their single merge variable.
          rss <- forLoop ry [rss_init] $ \i [rss_merge] -> do
            rss' <- forLoop rx [rss_merge] $ \j [rss_merge'] -> do
              prereqAddCode2 (gtid_x, gtid_y) (ii, i, jj, j) (css, redomap_orig_res)
              -- code2' refers to the original accumulator; make it use
              -- the current loop-carried accumulator instead.
              let code2_subs = substituteNames (M.singleton rss_init rss_merge') code2'

              res_el <-
                letSubExp "res_elem"
                  =<< eIf
                    ( toExp $
                        le64 gtid_y .<. pe64 height_A
                          .&&. le64 gtid_x .<. pe64 width_B
                    )
                    ( do
                        addStms code2_subs
                        resultBodyM [Var res_nm]
                    )
                    -- out of bounds: pass the accumulator through unchanged
                    (resultBodyM [Var rss_merge'])
              resultBodyM [res_el]
            resultBodyM [Var rss']
          pure [varRes rss]
        let epilogue_res_acc : _ = rssss_list
        pure [Returns ResultMaySimplify (Certs []) $ Var epilogue_res_acc]
mmBlkRegTilingAcc _ _ = pure Nothing

--------------------------
--------------------------

mmBlkRegTilingNrm :: Env -> Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
mmBlkRegTilingNrm :: Env -> Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
mmBlkRegTilingNrm Env
env (Let Pat (LetDec GPU)
pat StmAux (ExpDec GPU)
aux (Op (SegOp (SegMap SegThread {} SegSpace
seg_space [TypeBase Shape NoUniqueness]
ts KernelBody GPU
old_kbody))))
  | KernelBody () Stms GPU
kstms [Returns ResultManifest
ResultMaySimplify Certs
cs (Var VName
res_nm)] <- KernelBody GPU
old_kbody,
    Certs
cs forall a. Eq a => a -> a -> Bool
== forall a. Monoid a => a
mempty,
    -- check kernel has one result of primitive type
    [TypeBase Shape NoUniqueness
res_tp] <- [TypeBase Shape NoUniqueness]
ts,
    forall shape u. TypeBase shape u -> Bool
primType TypeBase Shape NoUniqueness
res_tp,
    -- we get the global-thread id for the two inner dimensions,
    --   as we are probably going to use it in code generation
    (VName
gtid_x, SubExp
width_B) : (VName
gtid_y, SubExp
height_A) : [(VName, SubExp)]
rem_outer_dims_rev <-
      forall a. [a] -> [a]
reverse forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
seg_space,
    [(VName, SubExp)]
rem_outer_dims <- forall a. [a] -> [a]
reverse [(VName, SubExp)]
rem_outer_dims_rev,
    Just
      ( Stms GPU
code2',
        (Stm GPU
load_A, VName
inp_A, PrimType
map_t1, Stm GPU
load_B, VName
inp_B, PrimType
map_t2),
        SubExp
common_dim,
        [Int]
var_dims,
        (Lambda GPU
map_lam, Lambda GPU
red_lam, SubExp
red_ne, VName
redomap_orig_res, PrimType
red_t)
        ) <-
      SegSpace
-> Stms GPU
-> Maybe
     (Stms GPU, (Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
      SubExp, [Int], (Lambda GPU, Lambda GPU, SubExp, VName, PrimType))
matchesBlkRegTile SegSpace
seg_space Stms GPU
kstms = do
      -- Here we start the implementation
      ---- in this binder: host code and outer seggroup (ie. the new kernel) ----
      (Stm GPU
new_kernel, Stms GPU
host_stms) <- forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder forall a b. (a -> b) -> a -> b
$ do
        -- host code
        (SubExp
rx, SubExp
ry, SubExp
tx, SubExp
ty, SubExp
tk, SubExp
tk_div_tx, SubExp
tk_div_ty, SubExp
tx_rx, SubExp
ty_ry, SubExp
a_loc_sz, SubExp
b_loc_sz) <-
          SubExp
-> SubExp
-> SubExp
-> Builder
     GPU
     (SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp,
      SubExp, SubExp, SubExp)
mkTileMemSizes SubExp
height_A SubExp
width_B SubExp
common_dim

        SubExp
gridDim_x <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"gridDim_x" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv SubExp
width_B SubExp
tx_rx
        SubExp
gridDim_y <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"gridDim_y" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv SubExp
height_A SubExp
ty_ry
        let gridxy_pexp :: TPrimExp Int64 VName
gridxy_pexp = SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_y forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_x
        let grid_pexp :: TPrimExp Int64 VName
grid_pexp =
              forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\TPrimExp Int64 VName
x SubExp
d -> SubExp -> TPrimExp Int64 VName
pe64 SubExp
d forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
x) TPrimExp Int64 VName
gridxy_pexp forall a b. (a -> b) -> a -> b
$
                forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> b
snd [(VName, SubExp)]
rem_outer_dims_rev
        (SubExp
grid_size, SubExp
group_size, SegLevel
segthd_lvl) <- SubExp
-> SubExp
-> TPrimExp Int64 VName
-> Builder GPU (SubExp, SubExp, SegLevel)
mkNewSegthdLvl SubExp
tx SubExp
ty TPrimExp Int64 VName
grid_pexp

        (VName
gid_x, VName
gid_y, VName
gid_flat) <- Builder GPU (VName, VName, VName)
mkGidsXYF

        ---- in this binder: outer seggroup ----
        ([KernelResult]
ret_seggroup, Stms GPU
stms_seggroup) <- forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder forall a b. (a -> b) -> a -> b
$ do
          VName
iii <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"iii" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
gid_y forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
ty_ry)
          VName
jjj <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"jjj" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
gid_x forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx_rx)

          -- initialize register mem with neutral elements and create shmem
          (VName
cssss, VName
a_loc_init, VName
b_loc_init) <-
            (SubExp, SubExp, SubExp, SubExp, SubExp, SubExp)
-> (PrimType, PrimType, PrimType)
-> SegLevel
-> SubExp
-> Builder GPU (VName, VName, VName)
initRegShmem
              (SubExp
rx, SubExp
tx, SubExp
ry, SubExp
ty, SubExp
a_loc_sz, SubExp
b_loc_sz)
              (PrimType
map_t1, PrimType
map_t2, PrimType
red_t)
              SegLevel
segthd_lvl
              SubExp
red_ne

          -- build prologue.
          VName
full_tiles <-
            forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"full_tiles" forall a b. (a -> b) -> a -> b
$
              forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
                BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SQuot IntType
Int64 Safety
Unsafe) SubExp
common_dim SubExp
tk

          let ct_arg :: ((SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp),
 SegLevel, [Int], (VName, SubExp, VName, SubExp, SubExp),
 (SubExp, SubExp), (VName, VName),
 (Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
 (Lambda GPU, Lambda GPU))
ct_arg =
                ( (SubExp
rx, SubExp
ry, SubExp
tx, SubExp
ty, SubExp
tk, SubExp
tk_div_tx, SubExp
tk_div_ty, SubExp
tx_rx),
                  SegLevel
segthd_lvl,
                  [Int]
var_dims,
                  (VName
gtid_x, SubExp
width_B, VName
gtid_y, SubExp
height_A, SubExp
common_dim),
                  (SubExp
a_loc_sz, SubExp
b_loc_sz),
                  (VName
iii, VName
jjj),
                  (Stm GPU
load_A, VName
inp_A, PrimType
map_t1, Stm GPU
load_B, VName
inp_B, PrimType
map_t2),
                  (Lambda GPU
map_lam, Lambda GPU
red_lam)
                )

          [VName]
prologue_res_list <-
            SubExp
-> [VName]
-> (VName -> [VName] -> Builder GPU (Body GPU))
-> Builder GPU [VName]
forLoop' (VName -> SubExp
Var VName
full_tiles) [VName
cssss, VName
a_loc_init, VName
b_loc_init] forall a b. (a -> b) -> a -> b
$
              \VName
kk0 [VName
thd_res_merge, VName
a_loc_merge, VName
b_loc_merge] -> do
                [VName]
process_full_tiles <-
                  Env
-> ((SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp,
     SubExp),
    SegLevel, [Int], (VName, SubExp, VName, SubExp, SubExp),
    (SubExp, SubExp), (VName, VName),
    (Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
    (Lambda GPU, Lambda GPU))
-> VName
-> (VName, VName, VName)
-> Bool
-> Builder GPU [VName]
kkLoopBody Env
env ((SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp),
 SegLevel, [Int], (VName, SubExp, VName, SubExp, SubExp),
 (SubExp, SubExp), (VName, VName),
 (Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
 (Lambda GPU, Lambda GPU))
ct_arg VName
kk0 (VName
thd_res_merge, VName
a_loc_merge, VName
b_loc_merge) Bool
False

                forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
process_full_tiles

          let VName
prologue_res : VName
a_loc_reuse : VName
b_loc_reuse : [VName]
_ = [VName]
prologue_res_list

          -- build epilogue.
          [VName]
epilogue_res_list <- Env
-> ((SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp,
     SubExp),
    SegLevel, [Int], (VName, SubExp, VName, SubExp, SubExp),
    (SubExp, SubExp), (VName, VName),
    (Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
    (Lambda GPU, Lambda GPU))
-> VName
-> (VName, VName, VName)
-> Bool
-> Builder GPU [VName]
kkLoopBody Env
env ((SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp),
 SegLevel, [Int], (VName, SubExp, VName, SubExp, SubExp),
 (SubExp, SubExp), (VName, VName),
 (Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
 (Lambda GPU, Lambda GPU))
ct_arg VName
full_tiles (VName
prologue_res, VName
a_loc_reuse, VName
b_loc_reuse) Bool
True

          let VName
redomap_res : [VName]
_ = [VName]
epilogue_res_list

          -- support for non-empty code2'
          --  segmap (ltid_y < ty, ltid_x < tx) {
          --    for i < ry do
          --      for j < rx do
          --        res = if (iii+ltid_y*ry+i < height_A && jjj+ltid_x*rx+j < width_B)
          --              then code2' else dummy
          --        final_res[i,j] = res
          forall {a}.
SegLevel
-> (VName, VName)
-> (VName, TypeBase Shape NoUniqueness)
-> (SubExp, SubExp, SubExp, SubExp)
-> (VName, VName)
-> (VName, VName)
-> (SubExp, SubExp, [(a, SubExp)])
-> Stms GPU
-> Builder GPU [KernelResult]
mkEpiloguePrimRes
            SegLevel
segthd_lvl
            (VName
redomap_orig_res, VName
redomap_res)
            (VName
res_nm, TypeBase Shape NoUniqueness
res_tp)
            (SubExp
ty, SubExp
tx, SubExp
ry, SubExp
rx)
            (VName
iii, VName
jjj)
            (VName
gtid_y, VName
gtid_x)
            (SubExp
height_A, SubExp
width_B, [(VName, SubExp)]
rem_outer_dims)
            Stms GPU
code2'

        let grid :: KernelGrid
grid = Count NumGroups SubExp -> Count GroupSize SubExp -> KernelGrid
KernelGrid (forall {k} (u :: k) e. e -> Count u e
Count SubExp
grid_size) (forall {k} (u :: k) e. e -> Count u e
Count SubExp
group_size)
            level' :: SegLevel
level' = SegVirt -> Maybe KernelGrid -> SegLevel
SegGroup SegVirt
SegNoVirt (forall a. a -> Maybe a
Just KernelGrid
grid)
            space' :: SegSpace
space' = VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat ([(VName, SubExp)]
rem_outer_dims forall a. [a] -> [a] -> [a]
++ [(VName
gid_y, SubExp
gridDim_y), (VName
gid_x, SubExp
gridDim_x)])
            kbody' :: KernelBody GPU
kbody' = forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () Stms GPU
stms_seggroup [KernelResult]
ret_seggroup
        forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec GPU)
pat StmAux (ExpDec GPU)
aux forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall (op :: * -> *) rep. SegOp SegLevel rep -> HostOp op rep
SegOp forall a b. (a -> b) -> a -> b
$ forall lvl rep.
lvl
-> SegSpace
-> [TypeBase Shape NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegMap SegLevel
level' SegSpace
space' [TypeBase Shape NoUniqueness]
ts KernelBody GPU
kbody'
      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just (Stms GPU
host_stms, Stm GPU
new_kernel)
  where
    mkEpiloguePrimRes :: SegLevel
-> (VName, VName)
-> (VName, TypeBase Shape NoUniqueness)
-> (SubExp, SubExp, SubExp, SubExp)
-> (VName, VName)
-> (VName, VName)
-> (SubExp, SubExp, [(a, SubExp)])
-> Stms GPU
-> Builder GPU [KernelResult]
mkEpiloguePrimRes
      SegLevel
segthd_lvl
      (VName
redomap_orig_res, VName
redomap_res)
      (VName
res_nm, TypeBase Shape NoUniqueness
res_tp)
      (SubExp
ty, SubExp
tx, SubExp
ry, SubExp
rx)
      (VName
iii, VName
jjj)
      (VName
gtid_y, VName
gtid_x)
      (SubExp
height_A, SubExp
width_B, [(a, SubExp)]
rem_outer_dims)
      Stms GPU
code2' = do
        VName
epilogue_res <-
          if VName
redomap_orig_res forall a. Eq a => a -> a -> Bool
== VName
res_nm
            then forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
redomap_res -- epilogue_res_list
            else do
              [VName]
rssss_list <- [Char]
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> Builder GPU Result)
-> Builder GPU [VName]
segMap2D [Char]
"rssss" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
ty, SubExp
tx) forall a b. (a -> b) -> a -> b
$ \(VName
ltid_y, VName
ltid_x) -> do
                VName
rss_init <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> PrimType -> [SubExp] -> m VName
scratch [Char]
"rss_init" (forall shape u. TypeBase shape u -> PrimType
elemType TypeBase Shape NoUniqueness
res_tp) [SubExp
ry, SubExp
rx]
                (VName
css, VName
ii, VName
jj) <- (SubExp, SubExp)
-> (VName, VName)
-> (VName, VName, VName)
-> Builder GPU (VName, VName, VName)
getThdRedomapRes (SubExp
rx, SubExp
ry) (VName
ltid_x, VName
ltid_y) (VName
iii, VName
jjj, VName
redomap_res)
                VName
rss <- SubExp
-> [VName]
-> (VName -> [VName] -> Builder GPU (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forLoop SubExp
ry [VName
rss_init] forall a b. (a -> b) -> a -> b
$ \VName
i [VName
rss_merge] -> do
                  VName
rss' <- SubExp
-> [VName]
-> (VName -> [VName] -> Builder GPU (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forLoop SubExp
rx [VName
rss_merge] forall a b. (a -> b) -> a -> b
$ \VName
j [VName
rss_merge'] -> do
                    (VName, VName)
-> (VName, VName, VName, VName)
-> (VName, VName)
-> BuilderT GPU (State VNameSource) ()
prereqAddCode2 (VName
gtid_x, VName
gtid_y) (VName
ii, VName
i, VName
jj, VName
j) (VName
css, VName
redomap_orig_res)

                    SubExp
res_el <-
                      forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"res_elem"
                        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                          ( forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$
                              forall a. a -> TPrimExp Int64 a
le64 VName
gtid_y forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
height_A
                                forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. forall a. a -> TPrimExp Int64 a
le64 VName
gtid_x forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
width_B
                          )
                          ( do
                              forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms Stms GPU
code2'
                              forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [VName -> SubExp
Var VName
res_nm]
                          )
                          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [forall (m :: * -> *).
MonadBuilder m =>
TypeBase Shape NoUniqueness -> m (Exp (Rep m))
eBlank TypeBase Shape NoUniqueness
res_tp])
                    VName
rss'' <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> SubExp -> m VName
update [Char]
"rss" VName
rss_merge' [VName
i, VName
j] SubExp
res_el
                    forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [VName -> SubExp
Var VName
rss'']
                  forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [VName -> SubExp
Var VName
rss']
                forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName -> SubExpRes
varRes VName
rss]
              let VName
rssss : [VName]
_ = [VName]
rssss_list
              forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
rssss

        let regtile_ret_dims :: [(SubExp, SubExp, SubExp)]
regtile_ret_dims =
              forall a b. (a -> b) -> [a] -> [b]
map (\(a
_, SubExp
sz) -> (SubExp
sz, SubExp
se1, SubExp
se1)) [(a, SubExp)]
rem_outer_dims
                forall a. [a] -> [a] -> [a]
++ [(SubExp
height_A, SubExp
ty, SubExp
ry), (SubExp
width_B, SubExp
tx, SubExp
rx)]

        -- Add dummy dimensions to tile to reflect the outer dimensions.
        VName
epilogue_res' <-
          if forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(a, SubExp)]
rem_outer_dims
            then forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
epilogue_res
            else do
              TypeBase Shape NoUniqueness
epilogue_t <- forall rep (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
epilogue_res
              let ([SubExp]
block_dims, [SubExp]
rest_dims) = forall a. Int -> [a] -> ([a], [a])
splitAt Int
2 forall a b. (a -> b) -> a -> b
$ forall u. TypeBase Shape u -> [SubExp]
arrayDims TypeBase Shape NoUniqueness
epilogue_t
                  ones :: [SubExp]
ones = forall a b. (a -> b) -> [a] -> [b]
map (forall a b. a -> b -> a
const forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) [(a, SubExp)]
rem_outer_dims
                  new_shape :: Shape
new_shape = forall d. [d] -> ShapeBase d
Shape forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[SubExp]
ones, [SubExp]
block_dims, [SubExp]
ones, [SubExp]
rest_dims]
              forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"res_reshaped" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
                ReshapeKind -> Shape -> VName -> BasicOp
Reshape ReshapeKind
ReshapeArbitrary Shape
new_shape VName
epilogue_res
        forall (f :: * -> *) a. Applicative f => a -> f a
pure [Certs -> [(SubExp, SubExp, SubExp)] -> VName -> KernelResult
RegTileReturns forall a. Monoid a => a
mempty [(SubExp, SubExp, SubExp)]
regtile_ret_dims VName
epilogue_res']
mmBlkRegTilingNrm Env
_ Stm GPU
_ = forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a. Maybe a
Nothing

-- | Pattern match the properties of the code that we look to tile: a
-- redomap whose two input arrays are each invariant to one of the
-- last two (innermost) parallel dimensions.
--
-- On success the result is a tuple made of:
--
--   * the statements to run after the redomap (the statements of the
--     prefix on which the redomap is independent, followed by the
--     statements that come after the redomap);
--
--   * for each of the two streamed arrays, the statement recorded for
--     it by 'processIndirections', the array name and the element
--     type of the corresponding map-lambda parameter; the two triples
--     are swapped when the variant dimensions are not @[0, 1]@;
--
--   * the size returned by 'isTileableRedomap';
--
--   * the variant dimension of each streamed array, as computed by
--     'isInvarTo1of2InnerDims';
--
--   * the map lambda, the reduce lambda, the single neutral element,
--     the name bound by the redomap statement and the element type of
--     the first parameter of the reduce lambda.
--
-- 'Nothing' is returned as soon as one of the checks fails.
matchesBlkRegTile ::
  SegSpace ->
  Stms GPU ->
  Maybe
    ( Stms GPU,
      (Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
      SubExp,
      [Int],
      (Lambda GPU, Lambda GPU, SubExp, VName, PrimType)
    )
matchesBlkRegTile seg_space kstms
  | -- build the variance table, that records, for
    -- each variable name, the variables it depends on
    initial_variance <- M.map mempty $ scopeOfSegSpace seg_space,
    variance <- varianceInStms initial_variance kstms,
    -- check that the code fits the pattern having:
    -- some `code1`, followed by one Screma SOAC, followed by some `code2`
    (code1, Just screma_stmt, code2) <- matchCodeStreamCode kstms,
    Let pat_redomap _ (Op _) <- screma_stmt,
    -- checks that the Screma SOAC is actually a redomap and normalizes it
    Just (common_dim, arrs, (_, red_lam, red_nes, map_lam)) <-
      isTileableRedomap screma_stmt,
    -- check that exactly two 1D arrays are streamed through redomap,
    -- and the result of redomap is one scalar
    -- TODO (note kept from the original author): this whole thing
    -- needs rearranging, including inp_A and inp_B.
    [arr_1, arr_2] <- arrs,
    [red_ne] <- red_nes,
    [map_t1t, map_t2t] <- map paramDec $ lambdaParams map_lam,
    [red_t1, _] <- map paramDec $ lambdaParams red_lam,
    primType map_t1t && primType map_t2t && primType red_t1,
    -- checks that the input arrays to redomap are variant to
    -- exactly one of the two innermost dimensions of the kernel
    Just var_dims <- isInvarTo1of2InnerDims mempty seg_space variance arrs,
    -- get the variables on which the first result of redomap depends on
    [redomap_orig_res] <- patNames pat_redomap,
    Just res_red_var <- M.lookup redomap_orig_res variance, -- variance of the reduce result

    -- we furthermore check that code1 is only formed by
    -- 1. statements that slice some globally-declared arrays
    --    to produce the input for the redomap, and
    -- 2. potentially some statements on which the redomap
    --    is independent; these are recorded in `code2''`
    Just (code2'', tab_inv_stm) <-
      L.foldl'
        (processIndirections (namesFromList arrs) res_red_var)
        (Just (Seq.empty, M.empty))
        code1,
    -- identify load_A, load_B: both streamed arrays must have an
    -- entry in the table built from code1
    Just load_1 <- M.lookup arr_1 tab_inv_stm,
    Just load_2 <- M.lookup arr_2 tab_inv_stm =
      let input_1 = (load_1, arr_1, elemType map_t1t)
          input_2 = (load_2, arr_2, elemType map_t2t)
          -- keep the arrays in redomap order when the variant
          -- dimensions are [0, 1]; otherwise swap them
          ((load_A, inp_A, map_t1), (load_B, inp_B, map_t2)) =
            if var_dims == [0, 1]
              then (input_1, input_2)
              else (input_2, input_1)
          code2' = code2'' <> code2
       in Just
            ( code2',
              (load_A, inp_A, map_t1, load_B, inp_B, map_t2),
              common_dim,
              var_dims,
              (map_lam, red_lam, red_ne, redomap_orig_res, elemType red_t1)
            )
matchesBlkRegTile _ _ = Nothing

-- | Build the expression for the ceiled division of @x@ by @y@: a
-- 'BinOp' applying 'SDivUp' at 'Int64' with 'Unsafe' safety.  The
-- expression is only constructed here; binding it is up to the caller.
ceilDiv :: MonadBuilder m => SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv numerator denominator =
  pure . BasicOp $ BinOp roundUpDiv numerator denominator
  where
    roundUpDiv = SDivUp Int64 Unsafe

-- | Obtain the tile sizes for the given @height_A@, @width_B@ and
-- @common_dim@ and bind the sizes derived from them.
--
-- The result is
-- @(rx, ry, tx, ty, tk, tk_div_tx, tk_div_ty, tx_rx, ty_ry, a_loc_sz, b_loc_sz)@
-- where @tk_div_tx@ and @tk_div_ty@ are the ceiled divisions of @tk@
-- by @tx@ and @ty@, @tx_rx = tx * rx@, @ty_ry = ty * ry@, and the two
-- @*_loc_sz@ values are the padded sizes computed below.
mkTileMemSizes ::
  SubExp ->
  SubExp ->
  SubExp ->
  Builder
    GPU
    ( SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp
    )
mkTileMemSizes height_A width_B common_dim = do
  -- Fresh names for the tile sizes; the generation order is kept
  -- as Tk, Tx, Ty, Rx, Ry.
  tk_name <- freshSizeName "Tk"
  tx_name <- freshSizeName "Tx"
  ty_name <- freshSizeName "Ty"
  rx_name <- freshSizeName "Rx"
  ry_name <- freshSizeName "Ry"

  (ty, ry) <- getParTiles ("Ty", "Ry") (ty_name, ry_name) height_A
  (tx, rx) <- getParTiles ("Tx", "Rx") (tx_name, rx_name) width_B
  tk <- getSeqTile "Tk" tk_name common_dim tx ty

  tk_div_tx <- ceilDiv tk tx >>= letSubExp "tk_div_tx"
  tk_div_ty <- ceilDiv tk ty >>= letSubExp "tk_div_ty"

  let tx_times_rx = pe64 tx * pe64 rx
      ty_times_ry = pe64 ty * pe64 ry
  tx_rx <- bindSize "TxRx" tx_times_rx
  ty_ry <- bindSize "TyRy" ty_times_ry

  let pad_term = sMax64 (pe64 tk) ty_times_ry
  -- if A not transposed, its shmem should be [ty*ry][tk]
  -- we pad to [ty*ry][tk+1] size to minimize bank conflicts
  a_loc_sz <- bindSize "a_loc_sz" (ty_times_ry * pe64 tk + pad_term)
  -- if B is transposed, its shmem should be [tk][tx*rx]
  -- we pad as above, by assuming tx*rx == ty*ry >= tk
  -- ToDo: we can decrease the size by checking at this
  --       point whether A and B are transposed (or not).
  b_loc_sz <- bindSize "b_loc_sz" (tx_times_rx * pe64 tk + pad_term)

  pure
    ( rx,
      ry,
      tx,
      ty,
      tk,
      tk_div_tx,
      tk_div_ty,
      tx_rx,
      ty_ry,
      a_loc_sz,
      b_loc_sz
    )
  where
    -- A fresh variable name, pretty-printed and turned into a 'Name'.
    freshSizeName desc = nameFromString . prettyString <$> newVName desc
    -- Bind the given expression under the given description.
    bindSize desc e = letSubExp desc =<< toExp e

-- | Bind the grid size (from the given expression) and the group
-- size (@ty * tx@), and return them together with a
-- 'SegThreadInGroup' level using 'SegNoVirtFull' and no sequential
-- dimensions.
mkNewSegthdLvl ::
  SubExp ->
  SubExp ->
  TPrimExp Int64 VName ->
  Builder GPU (SubExp, SubExp, SegLevel)
mkNewSegthdLvl tx ty grid_pexp = do
  grid_size <- toExp grid_pexp >>= letSubExp "grid_size"
  group_size <- toExp (pe64 ty * pe64 tx) >>= letSubExp "group_size"
  pure (grid_size, group_size, thread_level)
  where
    thread_level = SegThreadInGroup . SegNoVirtFull $ SegSeqDims []

-- | Create three fresh names, @gid_y@, @gid_x@ and @gid_flat@
-- (generated in that order), and return them as
-- @(gid_x, gid_y, gid_flat)@.
mkGidsXYF :: Builder GPU (VName, VName, VName)
mkGidsXYF =
  arrange
    <$> newVName "gid_y"
    <*> newVName "gid_x"
    <*> newVName "gid_flat"
  where
    -- The y name is generated first but returned second.
    arrange y x flat = (x, y, flat)

-- | Create the initial values of the three arrays threaded through
-- the tiling loops.
--
-- The first result comes from a 2D segmap over @(ty, tx)@ in which
-- each instance builds a @[ry][rx]@ array of type @red_t@ with every
-- element set to @red_ne@.  The other two results are scratch arrays
-- named @A_loc@ and @B_loc@, of element types @map_t1@ and @map_t2@
-- and sizes @a_loc_sz@ and @b_loc_sz@.
initRegShmem ::
  (SubExp, SubExp, SubExp, SubExp, SubExp, SubExp) ->
  (PrimType, PrimType, PrimType) ->
  SegLevel ->
  SubExp ->
  Builder GPU (VName, VName, VName)
initRegShmem
  (rx, tx, ry, ty, a_loc_sz, b_loc_sz)
  (map_t1, map_t2, red_t)
  segthd_lvl
  red_ne = do
    -- initialize register mem with neutral elements.
    reg_tiles <-
      segMap2D "cssss" segthd_lvl ResultPrivate (ty, tx) $
        const neutralTile
    let [cssss] = reg_tiles
    -- scratch shared memory
    a_loc_init <- scratch "A_loc" map_t1 [a_loc_sz]
    b_loc_init <- scratch "B_loc" map_t2 [b_loc_sz]
    pure (cssss, a_loc_init, b_loc_init)
    where
      -- A [ry][rx] scratch array overwritten element by element
      -- with the neutral element.
      neutralTile = do
        blank <- scratch "css_init" red_t [ry, rx]
        filled <- forLoop ry [blank] $ \i [outer_acc] -> do
          row_done <- forLoop rx [outer_acc] $ \j [inner_acc] -> do
            written <- update "css" inner_acc [i, j] red_ne
            resultBodyM [Var written]
          resultBodyM [Var row_done]
        pure [varRes filled]

-- | For the thread with local ids @(ltid_x, ltid_y)@, index
-- @redomap_res@ at @[ltid_y, ltid_x]@ and compute the offsets
-- @ii = iii + ltid_y * ry@ and @jj = jjj + ltid_x * rx@.
-- Returns @(indexed result, ii, jj)@.
getThdRedomapRes ::
  (SubExp, SubExp) ->
  (VName, VName) ->
  (VName, VName, VName) ->
  Builder GPU (VName, VName, VName)
getThdRedomapRes (rx, ry) (ltid_x, ltid_y) (iii, jjj, redomap_res) = do
  thd_res <- index "redomap_thd" redomap_res [ltid_y, ltid_x]
  row_start <- startOffset "ii" iii ltid_y ry
  col_start <- startOffset "jj" jjj ltid_x rx
  pure (thd_res, row_start, col_start)
  where
    -- base + ltid * reg, bound under the given description.
    startOffset desc base ltid reg =
      letExp desc =<< toExp (le64 base + le64 ltid * pe64 reg)

-- | Emit the bindings needed before the statements following the
-- redomap can be inserted: @redomap_orig_res@ is bound to element
-- @[i, j]@ of @css@, @gtid_y@ to @ii + i@ and @gtid_x@ to @jj + j@.
prereqAddCode2 ::
  (VName, VName) ->
  (VName, VName, VName, VName) ->
  (VName, VName) ->
  Builder GPU ()
prereqAddCode2 (gtid_x, gtid_y) (ii, i, jj, j) (css, redomap_orig_res) = do
  elm <- index "redomap_elm" css [i, j]
  -- rebind the original redomap result name to the indexed element
  addStm =<< mkLetNamesM [redomap_orig_res] (BasicOp (SubExp (Var elm)))
  bindSum gtid_y ii i
  bindSum gtid_x jj j
  where
    -- Bind the target name to the sum of the two given variables.
    bindSum target base off =
      letBindNames [target] =<< toExp (le64 base + le64 off)

-- | Tries to identify the following pattern:
--   code followed by some Screma followed by more code.
--
-- The statements are split at the first statement whose expression
-- is a 'Screma'.  The result holds the statements before it, the
-- Screma statement itself, and all the statements after it (later
-- Scremas included).  When there is no Screma, every statement ends
-- up in the first component, the second is 'Nothing' and the third
-- is empty.
matchCodeStreamCode ::
  Stms GPU ->
  (Stms GPU, Maybe (Stm GPU), Stms GPU)
matchCodeStreamCode kstms =
  -- A single `break` replaces the former left fold, which appended
  -- with `++` at every step and was therefore quadratic.
  case break isScrema (stmsToList kstms) of
    (code1, screma : code2) ->
      (stmsFromList code1, Just screma, stmsFromList code2)
    (code1, []) ->
      (stmsFromList code1, Nothing, stmsFromList [])
  where
    isScrema :: Stm GPU -> Bool
    isScrema (Let _ _ (Op (OtherOp Screma {}))) = True
    isScrema _ = False

-- | Classifies every streamed array by the single one of the last two
--   dimensions of the kernel space that it is variant to.
--
--   The check succeeds only when
--
--     * the kernel space has at least two dimensions, neither of whose
--       thread indices occurs in @branch_variant@;
--     * every array in @arrs@ is variant to exactly one of these two
--       dimensions; and
--     * each of the two dimensions has at least one array variant to it.
--
--   On success the result lists, in the order of @arrs@, @0@ for an
--   array variant to the second-to-last dimension and @1@ for an array
--   variant to the last dimension.
isInvarTo1of2InnerDims ::
  Names ->
  SegSpace ->
  VarianceTable ->
  [VName] ->
  Maybe [Int]
isInvarTo1of2InnerDims branch_variant kspace variance arrs =
  -- 'mapM' already fails if any single array cannot be classified.
  case mapM variantDimOf arrs of
    Just dims | 0 `elem` dims, 1 `elem` dims -> Just dims
    _ -> Nothing
  where
    variantDimOf :: VName -> Maybe Int
    variantDimOf arr =
      case reverse (unSegSpace kspace) of
        (j, _) : (i, _) : _
          -- if i or j is in branch_variant, give up
          | any (`nameIn` branch_variant) [j, i] -> Nothing
          | i `nameIn` deps, j `notNameIn` deps -> Just 0
          | j `nameIn` deps, i `notNameIn` deps -> Just 1
          where
            -- names that the array is recorded as depending on
            deps = M.findWithDefault mempty arr variance
        _ -> Nothing

-- | One step of a left fold over the statements preceding the redomap.
--
--   The accumulator pairs the statements the redomap does not depend on
--   with a table mapping each redomap input array to the 'Index'
--   statement that produces it.  A 'Nothing' accumulator is propagated
--   unchanged.  Otherwise:
--
--     * an 'Index' statement binding exactly one name that is one of the
--       redomap input arrays is recorded in the table;
--     * any other statement is appended to the independent statements,
--       provided none of the names it binds is in the variance set of
--       the redomap result;
--     * in all remaining cases the fold fails with 'Nothing'.
processIndirections ::
  Names -> -- input arrays to redomap
  Names -> -- variables on which the result of redomap depends on.
  Maybe (Stms GPU, M.Map VName (Stm GPU)) ->
  Stm GPU ->
  Maybe (Stms GPU, M.Map VName (Stm GPU))
processIndirections arrs res_red_var acc stm = do
  (indep_stms, slice_tab) <- acc
  case stm of
    Let pat _ (BasicOp Index {})
      | [pe] <- patElems pat,
        patElemName pe `nameIn` arrs ->
          Just (indep_stms, M.insert (patElemName pe) stm slice_tab)
    -- Also reached by 'Index' statements that fail the guard above.
    Let pat _ _
      | all ((`notNameIn` res_red_var) . patElemName) (patElems pat) ->
          Just (indep_stms Seq.|> stm, slice_tab)
      | otherwise -> Nothing

-- | Picks the pair (tile, register tile) for a parallel dimension of
--   length @len_dim@.
--
--   For the constant lengths 8, 16 and 32 the result is the constant
--   pair (8, 1), (8, 2) and (8, 4), respectively.  For any other length
--   two 'GetSize' queries are bound: one of class 'SizeTile' under the
--   key @t_name@ (binding described by @t_str@), and one of class
--   'SizeRegTile' under the key @r_name@ (binding described by @r_str@).
getParTiles :: (String, String) -> (Name, Name) -> SubExp -> Builder GPU (SubExp, SubExp)
getParTiles (t_str, r_str) (t_name, r_name) len_dim
  | Constant (IntValue (Int64Value n)) <- len_dim,
    Just reg_tile <- lookup n fixedRegTiles =
      pure (se8, reg_tile)
  | otherwise =
      (,)
        <$> sizeQuery t_str t_name SizeTile
        <*> sizeQuery r_str r_name SizeRegTile
  where
    -- register tile paired with a tile of 8, keyed by dimension length
    fixedRegTiles :: [(Int64, SubExp)]
    fixedRegTiles = [(8, se1), (16, se2), (32, se4)]

    sizeQuery :: String -> Name -> SizeClass -> Builder GPU SubExp
    sizeQuery desc key size_class =
      letSubExp desc $ Op $ SizeOp $ GetSize key size_class

-- | Picks the tile size for the sequential dimension of length
--   @len_dim@, given the two parallel tile sizes @tx@ and @ty@.
--
--   When both @tx@ and @ty@ are 64-bit integer constants, the tile is
--   bound (under the description @tk_str@) to the constant
--   @min tx ty@, further capped by @len_dim@ if that is a 64-bit
--   integer constant too.  Otherwise the tile is a 'GetSize' query of
--   class 'SizeTile' under the key @tk_name@.
getSeqTile :: String -> Name -> SubExp -> SubExp -> SubExp -> Builder GPU SubExp
getSeqTile tk_str tk_name len_dim tx ty
  | Just v_x <- asInt64 tx,
    Just v_y <- asInt64 ty =
      let par_tile = min v_x v_y
          seq_tile = maybe par_tile (`min` par_tile) (asInt64 len_dim)
       in letSubExp tk_str $ BasicOp $ SubExp $ constant seq_tile
  | otherwise =
      letSubExp tk_str $ Op $ SizeOp $ GetSize tk_name SizeTile
  where
    -- Extracts the value of a 64-bit integer constant, if it is one.
    asInt64 :: SubExp -> Maybe Int64
    asInt64 (Constant (IntValue (Int64Value v))) = Just v
    asInt64 _ = Nothing

----------------------------------------------------------------------------------------------
--- 3D Tiling (RegTiling for the outermost dimension & Block tiling for the innermost two) ---
----------------------------------------------------------------------------------------------

-- | Register budget for 3D tiling: 'doRegTiling3D' divides this value
--   by the number of results to obtain the register tile size.
maxRegTile :: Int64
maxRegTile = 30

-- | Wraps a register tile size as a constant 'SubExp'.
mkRegTileSe :: Int64 -> SubExp
mkRegTileSe reg_tile = constant reg_tile

-- | Is the name @nm@ variant to the thread index @gid_outer@?
--
--   This holds when @nm@ is @gid_outer@ itself, or when @gid_outer@ is
--   among the names recorded for @nm@ in the variance table (a name
--   without an entry is treated as depending on nothing).
variantToDim :: VarianceTable -> VName -> VName -> Bool
variantToDim variance gid_outer nm
  | gid_outer == nm = True
  | otherwise = gid_outer `nameIn` M.findWithDefault mempty nm variance

-- | Classifies every streamed array by the single one of the last
--   three dimensions of the kernel space that it is variant to.
--
--   The check succeeds only when
--
--     * the kernel space has at least three dimensions, none of whose
--       thread indices occurs in @branch_variant@;
--     * every array in @arrs@ is variant to exactly one of these three
--       dimensions (and hence invariant to the other two); and
--     * each of the three dimensions has at least one array variant
--       to it.
--
--   On success the result lists, in the order of @arrs@, @0@ for an
--   array variant to the third-to-last dimension, @1@ for the
--   second-to-last dimension and @2@ for the last dimension.
isInvarTo2of3InnerDims ::
  Names ->
  SegSpace ->
  VarianceTable ->
  [VName] ->
  Maybe [Int]
isInvarTo2of3InnerDims branch_variant kspace variance arrs =
  -- 'mapM' already fails if any single array cannot be classified.
  case mapM variantDimOf arrs of
    Just dims | all (`elem` dims) [0, 1, 2] -> Just dims
    _ -> Nothing
  where
    variantDimOf :: VName -> Maybe Int
    variantDimOf arr =
      case reverse (unSegSpace kspace) of
        (k, _) : (j, _) : (i, _) : _
          -- if i or j or k is in branch_variant, give up
          | any (`nameIn` branch_variant) [k, j, i] -> Nothing
          -- accept only when exactly one of the three indices is among
          -- the array's dependencies; its position is the answer
          | [dim] <- [d | (d, gid) <- zip [0 ..] [i, j, k], gid `nameIn` deps] ->
              Just dim
          where
            -- names that the array is recorded as depending on
            deps = M.findWithDefault mempty arr variance
        _ -> Nothing

-- | Expects a kernel statement as argument.
--   CONDITIONS for 3D tiling optimization to fire are:
--     1. a) The kernel body can be broken into
--              scalar-code-1 ++ [Redomap stmt] ++ scalar-code-2.
--        b) The kernel has a per-thread result, and obviously
--              the result is variant to the 3rd dimension
--              (counted from innermost to outermost)
--     2. For the Redomap:
--          a) the streamed arrays are one dimensional
--          b) each of the array arguments of Redomap are variant
--              to exactly one of the three innermost-parallel dimension
--              of the kernel. This condition can be relaxed by interchanging
--              kernel dimensions whenever possible.
--     3. For scalar-code-1:
--          a) each of the statements is a slice that produces one of the
--             streamed arrays
--
-- mmBlkRegTiling :: Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
-- mmBlkRegTiling (Let pat aux (Op (SegOp (SegMap SegThread{} seg_space ts old_kbody))))
doRegTiling3D :: Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
doRegTiling3D :: Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
doRegTiling3D (Let Pat (LetDec GPU)
pat StmAux (ExpDec GPU)
aux (Op (SegOp SegOp SegLevel GPU
old_kernel)))
  | SegMap SegThread {} SegSpace
space [TypeBase Shape NoUniqueness]
kertp (KernelBody () Stms GPU
kstms [KernelResult]
kres) <- SegOp SegLevel GPU
old_kernel,
    -- build the variance table, that records, for
    -- each variable name, the variables it depends on
    Map VName Names
initial_variance <- forall a b k. (a -> b) -> Map k a -> Map k b
M.map forall a. Monoid a => a
mempty forall a b. (a -> b) -> a -> b
$ forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space,
    Map VName Names
variance <- Map VName Names -> Stms GPU -> Map VName Names
varianceInStms Map VName Names
initial_variance Stms GPU
kstms,
    -- we get the global-thread id for the two inner dimensions,
    --   as we are probably going to use it in code generation
    (VName
gtid_x, SubExp
d_Kx) : (VName
gtid_y, SubExp
d_Ky) : (VName
gtid_z, SubExp
d_M) : [(VName, SubExp)]
rem_outer_dims_rev <- forall a. [a] -> [a]
reverse forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space,
    [(VName, SubExp)]
rem_outer_dims <- forall a. [a] -> [a]
reverse [(VName, SubExp)]
rem_outer_dims_rev,
    -- check that the code fits the pattern having:
    -- some `code1`, followed by one Screma SOAC, followed by some `code2`
    (Stms GPU
code1, Just Stm GPU
screma_stmt, Stms GPU
code2) <- Stms GPU -> (Stms GPU, Maybe (Stm GPU), Stms GPU)
matchCodeStreamCode Stms GPU
kstms,
    Let Pat (LetDec GPU)
pat_redomap StmAux (ExpDec GPU)
_ (Op Op GPU
_) <- Stm GPU
screma_stmt,
    -- checks that the Screma SOAC is actually a redomap and normalize it
    Just (SubExp
common_dim, [VName]
inp_soac_arrs, (Commutativity
_, Lambda GPU
red_lam, [SubExp]
red_nes, Lambda GPU
map_lam)) <- Stm GPU
-> Maybe
     (SubExp, [VName],
      (Commutativity, Lambda GPU, [SubExp], Lambda GPU))
isTileableRedomap Stm GPU
screma_stmt,
    Bool -> Bool
not (forall (t :: * -> *) a. Foldable t => t a -> Bool
null [SubExp]
red_nes),
    -- assuming we have a budget of maxRegTile registers, we distribute
    -- that budget across the result of redomap and the kernel result
    Int
num_res <- forall a. Ord a => a -> a -> a
max (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) (forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
kres),
    Int64
reg_tile <- Int64
maxRegTile forall a. Integral a => a -> a -> a
`quot` forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
num_res,
    SubExp
reg_tile_se <- Int64 -> SubExp
mkRegTileSe Int64
reg_tile,
    -- check that the element-type of the map and reduce are scalars:
    forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (forall shape u. TypeBase shape u -> Bool
primType forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall dec. Param dec -> dec
paramDec) forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPU
map_lam,
    [TypeBase Shape NoUniqueness]
red_res_tps <- forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> dec
paramDec forall a b. (a -> b) -> a -> b
$ forall a. Int -> [a] -> [a]
take (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPU
red_lam,
    forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all forall shape u. TypeBase shape u -> Bool
primType [TypeBase Shape NoUniqueness]
red_res_tps,
    -- checks that the input arrays to redomap are variant to
    -- exactly one of the two innermost dimensions of the kernel
    Just [Int]
_ <- Names -> SegSpace -> Map VName Names -> [VName] -> Maybe [Int]
isInvarTo2of3InnerDims forall a. Monoid a => a
mempty SegSpace
space Map VName Names
variance [VName]
inp_soac_arrs,
    -- get the free variables on which the result of redomap depends on
    [PatElem (TypeBase Shape NoUniqueness)]
redomap_orig_res <- forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec GPU)
pat_redomap,
    Names
res_red_var <- -- variance of the reduce result
      forall a. Monoid a => [a] -> a
mconcat forall a b. (a -> b) -> a -> b
$ forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe ((forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` Map VName Names
variance) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall dec. PatElem dec -> VName
patElemName) [PatElem (TypeBase Shape NoUniqueness)]
redomap_orig_res,
    forall a. Monoid a => a
mempty forall a. Eq a => a -> a -> Bool
/= Names
res_red_var,
    -- we furthermore check that code1 is only formed by
    -- 1. statements that slice some globally-declared arrays
    --    to produce the input for the redomap, and
    -- 2. potentially some statements on which the redomap
    --    is independent; these are recorded in `code2''`
    Just (Stms GPU
code2'', Map VName (Stm GPU)
arr_tab0) <-
      forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
        (Names
-> Names
-> Maybe (Stms GPU, Map VName (Stm GPU))
-> Stm GPU
-> Maybe (Stms GPU, Map VName (Stm GPU))
processIndirections ([VName] -> Names
namesFromList [VName]
inp_soac_arrs) Names
res_red_var)
        (forall a. a -> Maybe a
Just (forall a. Seq a
Seq.empty, forall k a. Map k a
M.empty))
        Stms GPU
code1,
    -- check that code1 contains exacly one slice for each of the input array to redomap
    [Stm GPU]
tmp_stms <- forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` Map VName (Stm GPU)
arr_tab0) [VName]
inp_soac_arrs,
    forall (t :: * -> *) a. Foldable t => t a -> Int
length [Stm GPU]
tmp_stms forall a. Eq a => a -> a -> Bool
== forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
inp_soac_arrs,
    -- code1' <- stmsFromList $ stmsToList code1 \\ stmsToList code2'',
    Stms GPU
code2' <- Stms GPU
code2'' forall a. Semigroup a => a -> a -> a
<> Stms GPU
code2,
    -- we assume the kernel results are variant to the thrid-outer parallel dimension
    -- (for sanity sake, they should be)
    [VName]
ker_res_nms <- forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe KernelResult -> Maybe VName
getResNm [KernelResult]
kres,
    forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
ker_res_nms forall a. Eq a => a -> a -> Bool
== forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
kres,
    forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all forall shape u. TypeBase shape u -> Bool
primType [TypeBase Shape NoUniqueness]
kertp,
    forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Map VName Names -> VName -> VName -> Bool
variantToDim Map VName Names
variance VName
gtid_z) [VName]
ker_res_nms = do
      -- HERE STARTS THE IMPLEMENTATION:
      (Stm GPU
new_kernel, Stms GPU
host_stms) <- forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder forall a b. (a -> b) -> a -> b
$ do
        -- host code
        -- process the z-variant arrays that need transposition;
        -- these "manifest" statements should come before the kernel
        (Map VName (Stm GPU)
tab_inn, Map VName (PrimType, Stm GPU)
tab_out) <-
          forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM
            (Map VName Names
-> (VName, SubExp)
-> (Map VName (Stm GPU), Map VName (PrimType, Stm GPU))
-> (VName, Stm GPU)
-> BuilderT
     GPU
     (State VNameSource)
     (Map VName (Stm GPU), Map VName (PrimType, Stm GPU))
insertTranspose Map VName Names
variance (VName
gtid_z, SubExp
d_M))
            (forall k a. Map k a
M.empty, forall k a. Map k a
M.empty)
            forall a b. (a -> b) -> a -> b
$ forall k a. Map k a -> [(k, a)]
M.toList Map VName (Stm GPU)
arr_tab0

        Name
tx_name <- [Char] -> Name
nameFromString forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. Pretty a => a -> [Char]
prettyString forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"Tx"
        Name
ty_name <- [Char] -> Name
nameFromString forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. Pretty a => a -> [Char]
prettyString forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"Ty"

        SubExp
tx0 <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"Tx" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall (op :: * -> *) rep. SizeOp -> HostOp op rep
SizeOp forall a b. (a -> b) -> a -> b
$ Name -> SizeClass -> SizeOp
GetSize Name
tx_name SizeClass
SizeTile
        SubExp
ty0 <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"Ty" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall (op :: * -> *) rep. SizeOp -> HostOp op rep
SizeOp forall a b. (a -> b) -> a -> b
$ Name -> SizeClass -> SizeOp
GetSize Name
ty_name SizeClass
SizeTile
        SubExp
ty <- [Char]
-> SubExp -> SubExp -> BuilderT GPU (State VNameSource) SubExp
limitTile [Char]
"Ty" SubExp
ty0 SubExp
d_Ky
        SubExp
tx <- [Char]
-> SubExp -> SubExp -> BuilderT GPU (State VNameSource) SubExp
limitTile [Char]
"Tx" SubExp
tx0 SubExp
d_Kx
        let rz :: SubExp
rz = SubExp
reg_tile_se

        SubExp
gridDim_x <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"gridDim_x" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv SubExp
d_Kx SubExp
tx
        SubExp
gridDim_y <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"gridDim_y" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv SubExp
d_Ky SubExp
ty
        SubExp
gridDim_z <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"gridDim_z" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv SubExp
d_M SubExp
rz
        let gridxyz_pexp :: TPrimExp Int64 VName
gridxyz_pexp = SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_z forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_y forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_x
        let grid_pexp :: TPrimExp Int64 VName
grid_pexp = forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
gridxyz_pexp forall a. a -> [a] -> [a]
: forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TPrimExp Int64 VName
pe64 forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> b
snd) [(VName, SubExp)]
rem_outer_dims_rev
        SubExp
grid_size <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"grid_size_tile3d" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp TPrimExp Int64 VName
grid_pexp
        SubExp
group_size <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"group_size_tile3d" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
ty forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx)
        let segthd_lvl :: SegLevel
segthd_lvl = SegVirt -> SegLevel
SegThreadInGroup (SegSeqDims -> SegVirt
SegNoVirtFull ([Int] -> SegSeqDims
SegSeqDims []))

        SubExp
count_shmem <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"count_shmem" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv SubExp
rz SubExp
group_size

        VName
gid_x <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_x"
        VName
gid_y <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_y"
        VName
gid_z <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_z"
        VName
gid_flat <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_flat"

        ---- in this binder: outer seggroup ----
        ([KernelResult]
ret_seggroup, Stms GPU
stms_seggroup) <- forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder forall a b. (a -> b) -> a -> b
$ do
          VName
ii <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"ii" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
gid_z forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
rz)
          VName
jj1 <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"jj1" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
gid_y forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
ty)
          VName
jj2 <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"jj2" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
gid_x forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx)

          -- initialize the register arrays corresponding to the result of redomap;
          [VName]
reg_arr_nms <- [Char]
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> Builder GPU Result)
-> Builder GPU [VName]
segMap2D [Char]
"res" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
ty, SubExp
tx) forall a b. (a -> b) -> a -> b
$ \(VName, VName)
_ ->
            forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
red_nes [TypeBase Shape NoUniqueness]
red_res_tps) forall a b. (a -> b) -> a -> b
$ \(SubExp
red_ne, TypeBase Shape NoUniqueness
red_t) -> do
              VName
css_init <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> PrimType -> [SubExp] -> m VName
scratch [Char]
"res_init" (forall shape u. TypeBase shape u -> PrimType
elemType TypeBase Shape NoUniqueness
red_t) [SubExp
rz]
              VName
css <- SubExp
-> [VName]
-> (VName -> [VName] -> Builder GPU (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forLoop SubExp
rz [VName
css_init] forall a b. (a -> b) -> a -> b
$ \VName
i [VName
css_merge] -> do
                VName
css' <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> SubExp -> m VName
update [Char]
"css" VName
css_merge [VName
i] SubExp
red_ne
                forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [VName -> SubExp
Var VName
css']
              forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ VName -> SubExpRes
varRes VName
css

          -- scratch the shared-memory arrays corresponding to the arrays that are
          --   input to the redomap and are invariant to the outermost parallel dimension.
          [VName]
loc_arr_nms <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall k a. Map k a -> [(k, a)]
M.toList Map VName (PrimType, Stm GPU)
tab_out) forall a b. (a -> b) -> a -> b
$ \(VName
nm, (PrimType
ptp, Stm GPU
_)) ->
            forall (m :: * -> *).
MonadBuilder m =>
[Char] -> PrimType -> [SubExp] -> m VName
scratch (VName -> [Char]
baseString VName
nm forall a. [a] -> [a] -> [a]
++ [Char]
"_loc") PrimType
ptp [SubExp
rz]

          [VName]
prologue_res_list <-
            SubExp
-> [VName]
-> (VName -> [VName] -> Builder GPU (Body GPU))
-> Builder GPU [VName]
forLoop' SubExp
common_dim ([VName]
reg_arr_nms forall a. [a] -> [a] -> [a]
++ [VName]
loc_arr_nms) forall a b. (a -> b) -> a -> b
$
              \VName
q [VName]
var_nms -> do
                let reg_arr_merge_nms :: [VName]
reg_arr_merge_nms = forall a. Int -> [a] -> [a]
take (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) [VName]
var_nms
                let loc_arr_merge_nms :: [VName]
loc_arr_merge_nms = forall a. Int -> [a] -> [a]
drop (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) [VName]
var_nms

                -- collective copy from global to shared memory
                [VName]
loc_arr_nms' <-
                  SubExp
-> [VName]
-> (VName -> [VName] -> Builder GPU (Body GPU))
-> Builder GPU [VName]
forLoop' SubExp
count_shmem [VName]
loc_arr_merge_nms forall a b. (a -> b) -> a -> b
$ \VName
tt [VName]
loc_arr_merge2_nms -> do
                    [VName]
loc_arr_merge2_nms' <-
                      forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
loc_arr_merge2_nms (forall k a. Map k a -> [(k, a)]
M.toList Map VName (PrimType, Stm GPU)
tab_out)) forall a b. (a -> b) -> a -> b
$ \(VName
loc_Y_nm, (VName
glb_Y_nm, (PrimType
ptp_Y, Stm GPU
load_Y))) -> do
                        VName
ltid_flat <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"ltid_flat"
                        VName
ltid <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"ltid"
                        let segspace :: SegSpace
segspace = VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
ltid_flat [(VName
ltid, SubExp
group_size)]
                        ((SubExp
res_v, SubExp
res_i), Stms GPU
stms) <- forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder forall a b. (a -> b) -> a -> b
$ do
                          VName
offs <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"offs" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
group_size forall a. Num a => a -> a -> a
* forall a. a -> TPrimExp Int64 a
le64 VName
tt)
                          VName
loc_ind <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"loc_ind" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
ltid forall a. Num a => a -> a -> a
+ forall a. a -> TPrimExp Int64 a
le64 VName
offs)
                          forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
gtid_z] forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
ii forall a. Num a => a -> a -> a
+ forall a. a -> TPrimExp Int64 a
le64 VName
loc_ind)
                          let glb_ind :: VName
glb_ind = VName
gtid_z
                          SubExp
y_elm <-
                            forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"y_elem"
                              forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                                (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ forall a. a -> TPrimExp Int64 a
le64 VName
glb_ind forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_M)
                                ( do
                                    forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm Stm GPU
load_Y
                                    VName
res <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> m VName
index [Char]
"Y_elem" VName
glb_Y_nm [VName
q]
                                    forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [VName -> SubExp
Var VName
res]
                                )
                                (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [forall (m :: * -> *).
MonadBuilder m =>
TypeBase Shape NoUniqueness -> m (Exp (Rep m))
eBlank forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
ptp_Y])
                          SubExp
y_ind <-
                            forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"y_loc_ind"
                              forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                                (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ forall a. a -> TPrimExp Int64 a
le64 VName
loc_ind forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
rz)
                                (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp VName
loc_ind forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [SubExp]
letTupExp' [Char]
"loc_fi" forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM)
                                (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1)])
                          -- y_tp  <- subExpType y_elm
                          forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp
y_elm, SubExp
y_ind)

                        let ret :: KernelResult
ret = Certs -> Shape -> VName -> [(Slice SubExp, SubExp)] -> KernelResult
WriteReturns forall a. Monoid a => a
mempty (forall d. [d] -> ShapeBase d
Shape [SubExp
rz]) VName
loc_Y_nm [(forall d. [DimIndex d] -> Slice d
Slice [forall d. d -> DimIndex d
DimFix SubExp
res_i], SubExp
res_v)]
                        let body :: KernelBody GPU
body = forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () Stms GPU
stms [KernelResult
ret]

                        [VName]
res_nms <-
                          forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [VName]
letTupExp [Char]
"Y_glb2loc" forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Exp rep -> m (Exp rep)
renameExp forall a b. (a -> b) -> a -> b
$
                            forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
                              forall (op :: * -> *) rep. SegOp SegLevel rep -> HostOp op rep
SegOp forall a b. (a -> b) -> a -> b
$
                                forall lvl rep.
lvl
-> SegSpace
-> [TypeBase Shape NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegMap SegLevel
segthd_lvl SegSpace
segspace [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
ptp_Y] KernelBody GPU
body
                        let VName
res_nm : [VName]
_ = [VName]
res_nms
                        forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
res_nm
                    forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
loc_arr_merge2_nms'

                [VName]
redomap_res <-
                  [Char]
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> Builder GPU Result)
-> Builder GPU [VName]
segMap2D [Char]
"redomap_res" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
ty, SubExp
tx) forall a b. (a -> b) -> a -> b
$
                    \(VName
ltid_y, VName
ltid_x) -> do
                      forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
gtid_y] forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
jj1 forall a. Num a => a -> a -> a
+ forall a. a -> TPrimExp Int64 a
le64 VName
ltid_y)
                      forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
gtid_x] forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
jj2 forall a. Num a => a -> a -> a
+ forall a. a -> TPrimExp Int64 a
le64 VName
ltid_x)
                      [VName]
reg_arr_merge_nms_slc <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
reg_arr_merge_nms forall a b. (a -> b) -> a -> b
$ \VName
reg_arr_nm ->
                        forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> m VName
index [Char]
"res_reg_slc" VName
reg_arr_nm [VName
ltid_y, VName
ltid_x]
                      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> Result
subExpsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [SubExp]
letTupExp' [Char]
"redomap_guarded"
                        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                          (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ forall a. a -> TPrimExp Int64 a
le64 VName
gtid_y forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_Ky forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. forall a. a -> TPrimExp Int64 a
le64 VName
gtid_x forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_Kx)
                          ( do
                              [VName]
inp_scals_invar_outer <-
                                forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall k a. Map k a -> [(k, a)]
M.toList Map VName (Stm GPU)
tab_inn) forall a b. (a -> b) -> a -> b
$ \(VName
inp_arr_nm, Stm GPU
load_stm) -> do
                                  forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm Stm GPU
load_stm
                                  forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> m VName
index (VName -> [Char]
baseString VName
inp_arr_nm) VName
inp_arr_nm [VName
q]
                              -- build the loop of count rz whose body is semantically the redomap code
                              [VName]
reg_arr_merge_nms' <-
                                SubExp
-> [VName]
-> (VName -> [VName] -> Builder GPU (Body GPU))
-> Builder GPU [VName]
forLoop' SubExp
rz [VName]
reg_arr_merge_nms_slc forall a b. (a -> b) -> a -> b
$ \VName
i [VName]
reg_arr_mm_nms -> do
                                  forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
gtid_z] forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
ii forall a. Num a => a -> a -> a
+ forall a. a -> TPrimExp Int64 a
le64 VName
i)
                                  forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM
                                    forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [SubExp]
letTupExp' [Char]
"redomap_lam"
                                    forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                                      (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ forall a. a -> TPrimExp Int64 a
le64 VName
gtid_z forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_M)
                                      ( do
                                          -- read from shared memory
                                          [VName]
ys <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
loc_arr_nms' forall a b. (a -> b) -> a -> b
$ \VName
loc_arr_nm ->
                                            forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> m VName
index [Char]
"inp_reg_var2z" VName
loc_arr_nm [VName
i]
                                          [VName]
cs <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
reg_arr_mm_nms forall a b. (a -> b) -> a -> b
$ \VName
reg_arr_nm ->
                                            forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> m VName
index [Char]
"res_reg_var2z" VName
reg_arr_nm [VName
i]
                                          -- here we need to put in order the scalar inputs to map:
                                          let tab_scals :: Map VName VName
tab_scals =
                                                forall k a. Ord k => [(k, a)] -> Map k a
M.fromList forall a b. (a -> b) -> a -> b
$
                                                  forall a b. [a] -> [b] -> [(a, b)]
zip (forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst forall a b. (a -> b) -> a -> b
$ forall k a. Map k a -> [(k, a)]
M.toList Map VName (PrimType, Stm GPU)
tab_out) [VName]
ys
                                                    forall a. [a] -> [a] -> [a]
++ forall a b. [a] -> [b] -> [(a, b)]
zip (forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst forall a b. (a -> b) -> a -> b
$ forall k a. Map k a -> [(k, a)]
M.toList Map VName (Stm GPU)
tab_inn) [VName]
inp_scals_invar_outer
                                          [VName]
map_inp_scals <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
inp_soac_arrs forall a b. (a -> b) -> a -> b
$ \VName
arr_nm ->
                                            case forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
arr_nm Map VName VName
tab_scals of
                                              Maybe VName
Nothing -> forall a. HasCallStack => [Char] -> a
error [Char]
"Impossible case reached in tiling3D\n"
                                              Just VName
nm -> forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
nm
                                          Lambda GPU
map_lam' <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPU
map_lam
                                          Lambda GPU
red_lam' <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPU
red_lam
                                          Result
map_res_scals <- forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m Result
eLambda Lambda GPU
map_lam' (forall a b. (a -> b) -> [a] -> [b]
map (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName]
map_inp_scals)
                                          Result
red_res <- forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m Result
eLambda Lambda GPU
red_lam' (forall a b. (a -> b) -> [a] -> [b]
map forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
cs forall a. [a] -> [a] -> [a]
++ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
map_res_scals))
                                          [VName]
css <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
reg_arr_mm_nms Result
red_res) forall a b. (a -> b) -> a -> b
$ \(VName
reg_arr_nm, SubExpRes
c) ->
                                            forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> SubExp -> m VName
update (VName -> [Char]
baseString VName
reg_arr_nm) VName
reg_arr_nm [VName
i] (SubExpRes -> SubExp
resSubExp SubExpRes
c)
                                          forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
css
                                      )
                                      (forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
reg_arr_mm_nms)
                              forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
reg_arr_merge_nms'
                          )
                          (forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
reg_arr_merge_nms_slc)
                forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ [VName]
redomap_res forall a. [a] -> [a] -> [a]
++ [VName]
loc_arr_nms'

          -- support for non-empty code2'
          --  segmap (ltid_y < ty, ltid_x < tx) {
          --    for i < rz do
          --        res = if (ii+i < d_M && jj1+ltid_y < d_Ky && jj2 + ltid_x < d_Kx)
          --              then code2' else dummy
          --        final_res[i] = res
          --  }
          let redomap_res :: [VName]
redomap_res = forall a. Int -> [a] -> [a]
take (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) [VName]
prologue_res_list
          [VName]
epilogue_res <-
            if forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem (TypeBase Shape NoUniqueness)]
redomap_orig_res forall a. Eq a => a -> a -> Bool
== forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
ker_res_nms
              Bool -> Bool -> Bool
&& [VName]
ker_res_nms forall a. Eq a => a -> a -> Bool
== forall a b. (a -> b) -> [a] -> [b]
map forall dec. PatElem dec -> VName
patElemName [PatElem (TypeBase Shape NoUniqueness)]
redomap_orig_res
              then [Char]
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp, SubExp)
-> ((VName, VName, VName) -> Builder GPU Result)
-> Builder GPU [VName]
segMap3D [Char]
"rssss" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
se1, SubExp
ty, SubExp
tx) forall a b. (a -> b) -> a -> b
$ \(VName
_ltid_z, VName
ltid_y, VName
ltid_x) ->
                forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip [TypeBase Shape NoUniqueness]
kertp [VName]
redomap_res) forall a b. (a -> b) -> a -> b
$ \(TypeBase Shape NoUniqueness
res_tp, VName
res) -> do
                  VName
rss_init <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> PrimType -> [SubExp] -> m VName
scratch [Char]
"rss_init" (forall shape u. TypeBase shape u -> PrimType
elemType TypeBase Shape NoUniqueness
res_tp) [SubExp
rz, SubExp
se1, SubExp
se1]
                  forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> SubExpRes
varRes forall a b. (a -> b) -> a -> b
$
                    SubExp
-> [VName]
-> (VName -> [VName] -> Builder GPU (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forLoop SubExp
rz [VName
rss_init] forall a b. (a -> b) -> a -> b
$ \VName
i [VName
rss] -> do
                      let slice :: Slice SubExp
slice = forall d. [DimIndex d] -> Slice d
Slice [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i, forall d. d -> DimIndex d
DimFix SubExp
se0, forall d. d -> DimIndex d
DimFix SubExp
se0]
                      VName
thread_res <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> m VName
index [Char]
"thread_res" VName
res [VName
ltid_y, VName
ltid_x, VName
i]
                      SubExp
rss' <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"rss" forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update Safety
Unsafe VName
rss Slice SubExp
slice forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
thread_res
                      forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [SubExp
rss']
              else [Char]
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp, SubExp)
-> ((VName, VName, VName) -> Builder GPU Result)
-> Builder GPU [VName]
segMap3D [Char]
"rssss" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
se1, SubExp
ty, SubExp
tx) forall a b. (a -> b) -> a -> b
$ \(VName
_ltid_z, VName
ltid_y, VName
ltid_x) -> do
                forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
gtid_y] forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
jj1 forall a. Num a => a -> a -> a
+ forall a. a -> TPrimExp Int64 a
le64 VName
ltid_y)
                forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
gtid_x] forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
jj2 forall a. Num a => a -> a -> a
+ forall a. a -> TPrimExp Int64 a
le64 VName
ltid_x)
                [VName]
rss_init <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [TypeBase Shape NoUniqueness]
kertp forall a b. (a -> b) -> a -> b
$ \TypeBase Shape NoUniqueness
res_tp ->
                  forall (m :: * -> *).
MonadBuilder m =>
[Char] -> PrimType -> [SubExp] -> m VName
scratch [Char]
"rss_init" (forall shape u. TypeBase shape u -> PrimType
elemType TypeBase Shape NoUniqueness
res_tp) [SubExp
rz, SubExp
se1, SubExp
se1]
                [VName]
rss <- SubExp
-> [VName]
-> (VName -> [VName] -> Builder GPU (Body GPU))
-> Builder GPU [VName]
forLoop' SubExp
rz [VName]
rss_init forall a b. (a -> b) -> a -> b
$ \VName
i [VName]
rss_merge -> do
                  forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
gtid_z] forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
ii forall a. Num a => a -> a -> a
+ forall a. a -> TPrimExp Int64 a
le64 VName
i)
                  forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem (TypeBase Shape NoUniqueness)]
redomap_orig_res [VName]
redomap_res) forall a b. (a -> b) -> a -> b
$ \(PatElem (TypeBase Shape NoUniqueness)
o_res, VName
n_res) -> do
                    VName
c <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> m VName
index [Char]
"redomap_thd" VName
n_res [VName
ltid_y, VName
ltid_x, VName
i]
                    forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [forall dec. PatElem dec -> VName
patElemName PatElem (TypeBase Shape NoUniqueness)
o_res] forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
c)
                    forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
c
                  [SubExp]
res_els <-
                    forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [SubExp]
letTupExp' [Char]
"res_elem"
                      forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                        ( forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$
                            forall a. a -> TPrimExp Int64 a
le64 VName
gtid_y forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_Ky
                              forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. forall a. a -> TPrimExp Int64 a
le64 VName
gtid_x forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_Kx
                              forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. forall a. a -> TPrimExp Int64 a
le64 VName
gtid_z forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_M
                        )
                        ( do
                            forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms Stms GPU
code2'
                            forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
ker_res_nms
                        )
                        (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall (m :: * -> *).
MonadBuilder m =>
TypeBase Shape NoUniqueness -> m (Exp (Rep m))
eBlank [TypeBase Shape NoUniqueness]
kertp)
                  [SubExp]
rss' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
res_els [VName]
rss_merge) forall a b. (a -> b) -> a -> b
$ \(SubExp
res_el, VName
rs_merge) -> do
                    let slice :: Slice SubExp
slice = forall d. [DimIndex d] -> Slice d
Slice [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i, forall d. d -> DimIndex d
DimFix SubExp
se0, forall d. d -> DimIndex d
DimFix SubExp
se0]
                    forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"rss" forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update Safety
Unsafe VName
rs_merge Slice SubExp
slice SubExp
res_el
                  forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [SubExp]
rss'
                forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ [VName] -> Result
varsRes [VName]
rss

          ----------------------------------------------------------------
          -- Finally, reshape the result arrays for the RegTileReturns ---
          ----------------------------------------------------------------
          let regtile_ret_dims :: [(SubExp, SubExp, SubExp)]
regtile_ret_dims =
                forall a b. (a -> b) -> [a] -> [b]
map (\(VName
_, SubExp
sz) -> (SubExp
sz, SubExp
se1, SubExp
se1)) [(VName, SubExp)]
rem_outer_dims
                  forall a. [a] -> [a] -> [a]
++ [(SubExp
d_M, SubExp
se1, SubExp
rz), (SubExp
d_Ky, SubExp
ty, SubExp
se1), (SubExp
d_Kx, SubExp
tx, SubExp
se1)]

          [VName]
epilogue_res' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
epilogue_res forall a b. (a -> b) -> a -> b
$ \VName
res ->
            if forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(VName, SubExp)]
rem_outer_dims
              then forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
res
              else do
                -- Add dummy dimensions to tile to reflect the outer dimensions
                TypeBase Shape NoUniqueness
res_tp' <- forall rep (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
res
                let ([SubExp]
block_dims, [SubExp]
rest_dims) = forall a. Int -> [a] -> ([a], [a])
splitAt Int
2 forall a b. (a -> b) -> a -> b
$ forall u. TypeBase Shape u -> [SubExp]
arrayDims TypeBase Shape NoUniqueness
res_tp'
                    ones :: [SubExp]
ones = forall a b. (a -> b) -> [a] -> [b]
map (forall a b. a -> b -> a
const SubExp
se1) [(VName, SubExp)]
rem_outer_dims
                    new_shape :: Shape
new_shape = forall d. [d] -> ShapeBase d
Shape forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[SubExp]
ones, [SubExp]
block_dims, [SubExp]
ones, [SubExp]
rest_dims]
                forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"res_reshaped" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
                  ReshapeKind -> Shape -> VName -> BasicOp
Reshape ReshapeKind
ReshapeArbitrary Shape
new_shape VName
res

          forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (Certs -> [(SubExp, SubExp, SubExp)] -> VName -> KernelResult
RegTileReturns forall a. Monoid a => a
mempty [(SubExp, SubExp, SubExp)]
regtile_ret_dims) [VName]
epilogue_res'
        -- END (ret_seggroup, stms_seggroup) <- runBuilder $ do
        let grid :: KernelGrid
grid = Count NumGroups SubExp -> Count GroupSize SubExp -> KernelGrid
KernelGrid (forall {k} (u :: k) e. e -> Count u e
Count SubExp
grid_size) (forall {k} (u :: k) e. e -> Count u e
Count SubExp
group_size)
            level' :: SegLevel
level' = SegVirt -> Maybe KernelGrid -> SegLevel
SegGroup SegVirt
SegNoVirt (forall a. a -> Maybe a
Just KernelGrid
grid)
            space' :: SegSpace
space' = VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat ([(VName, SubExp)]
rem_outer_dims forall a. [a] -> [a] -> [a]
++ [(VName
gid_z, SubExp
gridDim_z), (VName
gid_y, SubExp
gridDim_y), (VName
gid_x, SubExp
gridDim_x)])
            kbody' :: KernelBody GPU
kbody' = forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () Stms GPU
stms_seggroup [KernelResult]
ret_seggroup

        forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec GPU)
pat StmAux (ExpDec GPU)
aux forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall (op :: * -> *) rep. SegOp SegLevel rep -> HostOp op rep
SegOp forall a b. (a -> b) -> a -> b
$ forall lvl rep.
lvl
-> SegSpace
-> [TypeBase Shape NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegMap SegLevel
level' SegSpace
space' [TypeBase Shape NoUniqueness]
kertp KernelBody GPU
kbody'
      -- END (new_kernel, host_stms) <- runBuilder $ do
      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just (Stms GPU
host_stms, Stm GPU
new_kernel)
  where
    -- Extract the variable name carried by a kernel result.  Only a
    -- 'Returns' result with the 'ResultMaySimplify' manifest whose payload
    -- is a variable yields a name; every other kind of result gives
    -- 'Nothing'.
    getResNm :: KernelResult -> Maybe VName
    getResNm kres =
      case kres of
        Returns ResultMaySimplify _ (Var nm) -> Just nm
        _ -> Nothing

    -- Bind a fresh scalar, using @hint@ as the name hint, that holds the
    -- signed 64-bit minimum ('SMin' 'Int64') of the two operands.
    -- The argument names suggest a tile size being capped by a dimension
    -- size; only the minimum itself is established here.
    limitTile :: String -> SubExp -> SubExp -> Builder GPU SubExp
    limitTile hint tile dim =
      letSubExp hint (BasicOp (BinOp (SMin Int64) tile dim))
    -- Classify one index statement @p_nm = arr_nm[slc]@ with respect to the
    -- parallel dimension @gidz@ and record it in one of two tables:
    --
    --   * If no slice dimension is variant to @gidz@ (as decided by
    --     'variantSliceDim'), the statement is inserted unchanged into the
    --     first table.
    --
    --   * Otherwise, with @i@ the first variant slice position, a 'Manifest'
    --     of the array is bound using the permutation
    --     @[i+1 .. rank-1] ++ [0 .. i]@, the statement is rewritten to index
    --     the manifested array with the same slice, and the rewritten
    --     statement is inserted (paired with its element type) into the
    --     second table.
    --
    -- The statement must be a single-result 'Index' whose pattern name is
    -- @p_nm@; anything else is treated as an internal error.
    insertTranspose ::
      VarianceTable ->
      (VName, SubExp) ->
      (M.Map VName (Stm GPU), M.Map VName (PrimType, Stm GPU)) ->
      (VName, Stm GPU) ->
      Builder GPU (M.Map VName (Stm GPU), M.Map VName (PrimType, Stm GPU))
    insertTranspose variance (gidz, _) (tab_inn, tab_out) (p_nm, stm)
      | Let patt aux (BasicOp (Index arr_nm slc)) <- stm,
        [pe] <- patElems patt,
        p_nm == patElemName pe =
          case L.findIndices (variantSliceDim variance gidz) (unSlice slc) of
            -- Invariant to gidz: keep the statement as it is.
            [] ->
              pure (M.insert p_nm stm tab_inn, tab_out)
            -- Variant to gidz: only the first variant position matters.
            i : _ -> do
              rank <- arrayRank <$> lookupType arr_nm
              let perm = [i + 1 .. rank - 1] ++ [0 .. i]
                  transp_hint = baseString arr_nm ++ "_transp"
              arr_tr_nm <- letExp transp_hint (BasicOp (Manifest perm arr_nm))
              let ptp = elemType (patElemType pe)
                  stm' = Let patt aux (BasicOp (Index arr_tr_nm slc))
              pure (tab_inn, M.insert p_nm (ptp, stm') tab_out)
      | otherwise =
          error "\nUnreachable case reached in insertTranspose case, doRegTiling3D\n"

    -- Decide whether a slice dimension is variant to the parallel
    -- dimension @gidz@.  Only a fixed index given by a variable can be
    -- variant, and for it the answer is delegated to 'variantToDim'; any
    -- other form of dimension index counts as invariant.
    variantSliceDim :: VarianceTable -> VName -> DimIndex SubExp -> Bool
    variantSliceDim variance gidz dim_ix =
      case dim_ix of
        DimFix (Var vnm) -> variantToDim variance gidz vnm
        _ -> False
-- Catch-all equation: a statement that the preceding equation does not
-- accept produces no tiled replacement.
doRegTiling3D _ = pure Nothing