{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Perform a restricted form of block+register tiling corresponding to
--   the following pattern:
--     * a redomap is quasi-perfectly nested inside a kernel with at
--       least two parallel dimensions (the perfectly nested restriction
--       is relaxed a bit to allow for SGEMM);
--     * all streamed arrays of redomap are one-dimensional;
--     * all streamed arrays are variant to exactly one of the two
--       innermost parallel dimensions, and, conversely, for each of
--       the two innermost parallel dimensions, there is at least
--       one streamed array variant to it;
--     * the stream's result is a tuple of scalar values, which are
--       also the "thread-in-space" return of the kernel.
--     * We have further restrictions that in principle can be relaxed:
--          the redomap has exactly two array inputs;
--          the redomap produces one scalar result;
--          the kernel produces one scalar result.
module Futhark.Optimise.BlkRegTiling (mmBlkRegTiling, doRegTiling3D) where

import Control.Monad.Reader
import qualified Data.List as L
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Sequence as Seq
import Futhark.IR.Kernels
import Futhark.MonadFreshNames
import Futhark.Optimise.TileLoops.Shared
import Futhark.Tools
import Futhark.Transform.Rename

mmBlkRegTiling :: Stm Kernels -> TileM (Maybe (Stms Kernels, Stm Kernels))
mmBlkRegTiling :: Stm Kernels -> TileM (Maybe (Stms Kernels, Stm Kernels))
mmBlkRegTiling (Let Pattern Kernels
pat StmAux (ExpDec Kernels)
aux (Op (SegOp (SegMap SegThread {} seg_space ts old_kbody))))
  | KernelBody () Stms Kernels
kstms [Returns ResultManifest
ResultMaySimplify (Var VName
res_nm)] <- KernelBody Kernels
old_kbody,
    -- check kernel has one result of primitive type
    [Type
res_tp] <- [Type]
ts,
    Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType Type
res_tp,
    -- build the variance table, that records, for
    -- each variable name, the variables it depends on
    Map VName Names
initial_variance <- (NameInfo Any -> Names)
-> Map VName (NameInfo Any) -> Map VName Names
forall a b k. (a -> b) -> Map k a -> Map k b
M.map NameInfo Any -> Names
forall a. Monoid a => a
mempty (Map VName (NameInfo Any) -> Map VName Names)
-> Map VName (NameInfo Any) -> Map VName Names
forall a b. (a -> b) -> a -> b
$ SegSpace -> Map VName (NameInfo Any)
forall lore. SegSpace -> Scope lore
scopeOfSegSpace SegSpace
seg_space,
    Map VName Names
variance <- Map VName Names -> Stms Kernels -> Map VName Names
varianceInStms Map VName Names
initial_variance Stms Kernels
kstms,
    -- check that the code fits the pattern having:
    -- some `code1`, followed by one Screma SOAC, followed by some `code2`
    (Stms Kernels
code1, Just Stm Kernels
screma_stmt, Stms Kernels
code2) <- Stms Kernels -> (Stms Kernels, Maybe (Stm Kernels), Stms Kernels)
matchCodeStreamCode Stms Kernels
kstms,
    Let Pattern Kernels
pat_redomap StmAux (ExpDec Kernels)
_ (Op Op Kernels
_) <- Stm Kernels
screma_stmt,
    -- checks that the Screma SOAC is actually a redomap and normalizes it
    Just (SubExp
common_dim, [VName]
arrs, (Commutativity
_, Lambda Kernels
red_lam, [SubExp]
red_nes, Lambda Kernels
map_lam)) <- Stm Kernels
-> Maybe
     (SubExp, [VName],
      (Commutativity, Lambda Kernels, [SubExp], Lambda Kernels))
isTileableRedomap Stm Kernels
screma_stmt,
    -- check that exactly two 1D arrays are streamed thorugh redomap,
    -- and the result of redomap is one scalar
    -- !!!I need to rearrange this whole thing!!! including inp_A and inp_B
    [VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
arrs Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
2 Bool -> Bool -> Bool
&& [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1,
    [Type
map_t1t, Type
map_t2t] <- (Param Type -> Type) -> [Param Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Type
forall dec. Param dec -> dec
paramDec ([Param Type] -> [Type]) -> [Param Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda Kernels -> [LParam Kernels]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda Kernels
map_lam,
    [Type
red_t1, Type
_] <- (Param Type -> Type) -> [Param Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Type
forall dec. Param dec -> dec
paramDec ([Param Type] -> [Type]) -> [Param Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda Kernels -> [LParam Kernels]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda Kernels
red_lam,
    Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType Type
map_t1t Bool -> Bool -> Bool
&& Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType Type
map_t2t Bool -> Bool -> Bool
&& Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType Type
red_t1,
    PrimType
map_t1 <- Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
map_t1t,
    PrimType
map_t2 <- Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
map_t2t,
    -- checks that the input arrays to redomap are variant to
    -- exactly one of the two innermost dimensions of the kernel
    Just [Int]
var_dims <- Names -> SegSpace -> Map VName Names -> [VName] -> Maybe [Int]
isInvarTo1of2InnerDims Names
forall a. Monoid a => a
mempty SegSpace
seg_space Map VName Names
variance [VName]
arrs,
    -- get the variables on which the first result of redomap depends on
    [PatElemT Type
redomap_orig_res] <- PatternT Type -> [PatElemT Type]
forall dec. PatternT dec -> [PatElemT dec]
patternValueElements PatternT Type
Pattern Kernels
pat_redomap,
    Just Names
res_red_var <- VName -> Map VName Names -> Maybe Names
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (PatElemT Type -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT Type
redomap_orig_res) Map VName Names
variance, -- variance of the reduce result

    -- we furthermore check that code1 is only formed by
    -- 1. statements that slice some globally-declared arrays
    --    to produce the input for the redomap, and
    -- 2. potentially some statements on which the redomap
    --    is independent; these are recorded in `code2''`
    Just (Stms Kernels
code2'', Map VName (Stm Kernels)
tab_inv_stm) <-
      (Maybe (Stms Kernels, Map VName (Stm Kernels))
 -> Stm Kernels -> Maybe (Stms Kernels, Map VName (Stm Kernels)))
-> Maybe (Stms Kernels, Map VName (Stm Kernels))
-> Stms Kernels
-> Maybe (Stms Kernels, Map VName (Stm Kernels))
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
        (Names
-> Names
-> Maybe (Stms Kernels, Map VName (Stm Kernels))
-> Stm Kernels
-> Maybe (Stms Kernels, Map VName (Stm Kernels))
processIndirections ([VName] -> Names
namesFromList [VName]
arrs) Names
res_red_var)
        ((Stms Kernels, Map VName (Stm Kernels))
-> Maybe (Stms Kernels, Map VName (Stm Kernels))
forall a. a -> Maybe a
Just (Stms Kernels
forall a. Seq a
Seq.empty, Map VName (Stm Kernels)
forall k a. Map k a
M.empty))
        Stms Kernels
code1,
    -- identify load_A, load_B
    [Stm Kernels]
tmp_stms <- (VName -> Maybe (Stm Kernels)) -> [VName] -> [Stm Kernels]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (VName -> Map VName (Stm Kernels) -> Maybe (Stm Kernels)
forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` Map VName (Stm Kernels)
tab_inv_stm) [VName]
arrs,
    [Stm Kernels] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Stm Kernels]
tmp_stms Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
arrs,
    -- [inp_A, inp_B] <- arrs,
    [(Stm Kernels, VName)]
zip_AB <- [Stm Kernels] -> [VName] -> [(Stm Kernels, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Stm Kernels]
tmp_stms [VName]
arrs,
    [(Stm Kernels
load_A, VName
inp_A), (Stm Kernels
load_B, VName
inp_B)] <- if [Int]
var_dims [Int] -> [Int] -> Bool
forall a. Eq a => a -> a -> Bool
== [Int
0, Int
1] then [(Stm Kernels, VName)]
zip_AB else [(Stm Kernels, VName)] -> [(Stm Kernels, VName)]
forall a. [a] -> [a]
reverse [(Stm Kernels, VName)]
zip_AB,
    -- code1' <- stmsFromList $ stmsToList code1 \\ stmsToList code2'',
    Stms Kernels
code2' <- Stms Kernels
code2'' Stms Kernels -> Stms Kernels -> Stms Kernels
forall a. Semigroup a => a -> a -> a
<> Stms Kernels
code2,
    -- we get the global-thread id for the two inner dimensions,
    --   as we are probably going to use it in code generation
    (VName
gtid_x, SubExp
width_B) : (VName
gtid_y, SubExp
height_A) : [(VName, SubExp)]
rem_outer_dims_rev <- [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
seg_space,
    [(VName, SubExp)]
rem_outer_dims <- [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse [(VName, SubExp)]
rem_outer_dims_rev,
    -- sanity check that the reduce part is not missing
    Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [SubExp]
red_nes = do
    let SubExp
red_ne : [SubExp]
_ = [SubExp]
red_nes
    Type
red_t <- SubExp -> ReaderT (Scope Kernels) (State VNameSource) Type
forall t (m :: * -> *). HasScope t m => SubExp -> m Type
subExpType SubExp
red_ne

    ---- in this binder: host code and outer seggroup (ie. the new kernel) ----
    (Stm Kernels
new_kernel, Stms Kernels
host_stms) <- Binder Kernels (Stm Kernels)
-> ReaderT
     (Scope Kernels) (State VNameSource) (Stm Kernels, Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore a -> m (a, Stms lore)
runBinder (Binder Kernels (Stm Kernels)
 -> ReaderT
      (Scope Kernels) (State VNameSource) (Stm Kernels, Stms Kernels))
-> Binder Kernels (Stm Kernels)
-> ReaderT
     (Scope Kernels) (State VNameSource) (Stm Kernels, Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
      -- host code

      Name
tk_name <- String -> Name
nameFromString (String -> Name) -> (VName -> String) -> VName -> Name
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> String
forall a. Pretty a => a -> String
pretty (VName -> Name)
-> BinderT Kernels (State VNameSource) VName
-> BinderT Kernels (State VNameSource) Name
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"Tk"
      Name
tx_name <- String -> Name
nameFromString (String -> Name) -> (VName -> String) -> VName -> Name
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> String
forall a. Pretty a => a -> String
pretty (VName -> Name)
-> BinderT Kernels (State VNameSource) VName
-> BinderT Kernels (State VNameSource) Name
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"Tx"
      Name
ty_name <- String -> Name
nameFromString (String -> Name) -> (VName -> String) -> VName -> Name
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> String
forall a. Pretty a => a -> String
pretty (VName -> Name)
-> BinderT Kernels (State VNameSource) VName
-> BinderT Kernels (State VNameSource) Name
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"Ty"
      Name
rx_name <- String -> Name
nameFromString (String -> Name) -> (VName -> String) -> VName -> Name
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> String
forall a. Pretty a => a -> String
pretty (VName -> Name)
-> BinderT Kernels (State VNameSource) VName
-> BinderT Kernels (State VNameSource) Name
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"Rx"
      Name
ry_name <- String -> Name
nameFromString (String -> Name) -> (VName -> String) -> VName -> Name
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> String
forall a. Pretty a => a -> String
pretty (VName -> Name)
-> BinderT Kernels (State VNameSource) VName
-> BinderT Kernels (State VNameSource) Name
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"Ry"

      (SubExp
ty, SubExp
ry) <- (String, String)
-> (Name, Name) -> SubExp -> Binder Kernels (SubExp, SubExp)
getParTiles (String
"Ty", String
"Ry") (Name
ty_name, Name
ry_name) SubExp
height_A
      (SubExp
tx, SubExp
rx) <- (String, String)
-> (Name, Name) -> SubExp -> Binder Kernels (SubExp, SubExp)
getParTiles (String
"Tx", String
"Rx") (Name
tx_name, Name
rx_name) SubExp
width_B
      SubExp
tk <- String
-> Name -> SubExp -> SubExp -> SubExp -> Binder Kernels SubExp
getSeqTile String
"Tk" Name
tk_name SubExp
common_dim SubExp
ty SubExp
tx

      SubExp
tk_div_tx <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"tk_div_tx" (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> m (Exp (Lore m))
ceilDiv SubExp
tk SubExp
tx
      SubExp
tk_div_ty <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"tk_div_ty" (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> m (Exp (Lore m))
ceilDiv SubExp
tk SubExp
ty

      SubExp
tx_rx <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"TxRx" (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
rx)
      SubExp
ty_ry <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"TyRy" (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
ty TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
ry)

      SubExp
a_loc_sz <-
        String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"a_loc_sz"
          (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
ty TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
ry TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tk)

      SubExp
b_loc_sz <-
        String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"b_loc_sz"
          (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
tk TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
rx)

      SubExp
gridDim_x <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"gridDim_x" (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> m (Exp (Lore m))
ceilDiv SubExp
width_B SubExp
tx_rx
      SubExp
gridDim_y <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"gridDim_y" (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> m (Exp (Lore m))
ceilDiv SubExp
height_A SubExp
ty_ry
      let gridxy_pexp :: TPrimExp Int64 VName
gridxy_pexp = SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_x
      let grid_pexp :: TPrimExp Int64 VName
grid_pexp =
            (TPrimExp Int64 VName -> SubExp -> TPrimExp Int64 VName)
-> TPrimExp Int64 VName -> [SubExp] -> TPrimExp Int64 VName
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\TPrimExp Int64 VName
x SubExp
d -> SubExp -> TPrimExp Int64 VName
pe64 SubExp
d TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
x) TPrimExp Int64 VName
gridxy_pexp ([SubExp] -> TPrimExp Int64 VName)
-> [SubExp] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$
              ((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(VName, SubExp)]
rem_outer_dims_rev
      SubExp
grid_size <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"grid_size" (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp TPrimExp Int64 VName
grid_pexp
      SubExp
group_size <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"group_size" (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
ty TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx)
      let segthd_lvl :: SegLevel
segthd_lvl = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Count SubExp
grid_size) (SubExp -> Count GroupSize SubExp
forall u e. e -> Count u e
Count SubExp
group_size) SegVirt
SegNoVirtFull

      VName
gid_x <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"gid_x"
      VName
gid_y <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"gid_y"
      VName
gid_flat <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"gid_flat"

      ---- in this binder: outer seggroup ----
      ([KernelResult]
ret_seggroup, Stms Kernels
stms_seggroup) <- Binder Kernels [KernelResult]
-> BinderT
     Kernels (State VNameSource) ([KernelResult], Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore a -> m (a, Stms lore)
runBinder (Binder Kernels [KernelResult]
 -> BinderT
      Kernels (State VNameSource) ([KernelResult], Stms Kernels))
-> Binder Kernels [KernelResult]
-> BinderT
     Kernels (State VNameSource) ([KernelResult], Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
        VName
iii <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"iii" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
ty_ry)
        VName
jjj <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"jjj" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_x TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx_rx)

        -- initialize register mem with neutral elements.
        [VName]
cssss_list <- String
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
segMap2D String
"cssss" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
ty, SubExp
tx) (((VName, VName) -> Binder Kernels [SubExp])
 -> Binder Kernels [VName])
-> ((VName, VName) -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \(VName, VName)
_ -> do
          VName
css_init <- String
-> PrimType
-> [SubExp]
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> PrimType -> [SubExp] -> m VName
scratch String
"css_init" (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
red_t) [SubExp
ry, SubExp
rx]
          VName
css <- SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forLoop SubExp
ry [VName
css_init] ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> BinderT Kernels (State VNameSource) VName)
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
i [VName
css_merge] -> do
            VName
css' <- SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forLoop SubExp
rx [VName
css_merge] ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> BinderT Kernels (State VNameSource) VName)
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
j [VName
css_merge'] -> do
              VName
css'' <- String
-> VName
-> [VName]
-> SubExp
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> SubExp -> m VName
update' String
"css" VName
css_merge' [VName
i, VName
j] SubExp
red_ne
              [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
css'']
            [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
css']
          [SubExp] -> Binder Kernels [SubExp]
forall (m :: * -> *) a. Monad m => a -> m a
return [VName -> SubExp
Var VName
css]
        let [VName
cssss] = [VName]
cssss_list

        VName
a_loc_init <- String
-> PrimType
-> [SubExp]
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> PrimType -> [SubExp] -> m VName
scratch String
"A_loc" PrimType
map_t1 [SubExp
a_loc_sz]
        VName
b_loc_init <- String
-> PrimType
-> [SubExp]
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> PrimType -> [SubExp] -> m VName
scratch String
"B_loc" PrimType
map_t2 [SubExp
b_loc_sz]

        let kkLoopBody :: VName -> (VName, VName, VName) -> Bool -> Binder Kernels [VName]
kkLoopBody VName
kk0 (VName
thd_res_merge, VName
a_loc_init', VName
b_loc_init') Bool
epilogue = do
              VName
kk <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"kk" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
kk0 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tk)
              VName
a_loc <- SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forLoop SubExp
ry [VName
a_loc_init'] ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> BinderT Kernels (State VNameSource) VName)
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
i0 [VName
a_loc_merge] -> do
                VName
loop_a_loc <- SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forLoop SubExp
tk_div_tx [VName
a_loc_merge] ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> BinderT Kernels (State VNameSource) VName)
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
k0 [VName
a_loc_merge'] -> do
                  [VName]
scatter_a_loc <- String
-> SubExp
-> VName
-> SegLevel
-> (SubExp, SubExp)
-> ((VName, VName) -> Binder Kernels (SubExp, SubExp))
-> Binder Kernels [VName]
segScatter2D String
"A_glb2loc" SubExp
a_loc_sz VName
a_loc_merge' SegLevel
segthd_lvl (SubExp
ty, SubExp
tx) (((VName, VName) -> Binder Kernels (SubExp, SubExp))
 -> Binder Kernels [VName])
-> ((VName, VName) -> Binder Kernels (SubExp, SubExp))
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$
                    \(VName
thd_y, VName
thd_x) -> do
                      VName
k <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"k" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
thd_x TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
k0 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx)
                      VName
i <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"i" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
thd_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i0 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
ty)

                      [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
gtid_y] (ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
iii TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i)
                      VName
a_col_idx <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"A_col_idx" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
kk TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
k)

                      SubExp
a_elem <-
                        String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"A_elem"
                          (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
  Kernels
  (State VNameSource)
  (Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
                            ( TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$
                                VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_y TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
height_A
                                  TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. if Bool
epilogue
                                    then VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
a_col_idx TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
common_dim
                                    else TPrimExp Bool VName
forall v. TPrimExp Bool v
true
                            )
                            ( do
                                Stm (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm Stm (Lore (BinderT Kernels (State VNameSource)))
Stm Kernels
load_A
                                VName
res <- String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"A_elem" VName
inp_A [VName
a_col_idx]
                                [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
res]
                            )
                            ([BinderT
   Kernels
   (State VNameSource)
   (Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody [Type
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *). MonadBinder m => Type -> m (Exp (Lore m))
eBlank (Type
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> Type
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
map_t1])
                      SubExp
a_loc_ind <-
                        String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"a_loc_ind"
                          (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
  Kernels
  (State VNameSource)
  (Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
                            (TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
k TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
tk)
                            ( TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
k TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tk)
                                BinderT Kernels (State VNameSource) (ExpT Kernels)
-> (ExpT Kernels -> Binder Kernels [SubExp])
-> Binder Kernels [SubExp]
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels [SubExp]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [SubExp]
letTupExp' String
"loc_fi"
                                Binder Kernels [SubExp]
-> ([SubExp] -> Binder Kernels (Body Kernels))
-> Binder Kernels (Body Kernels)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= [SubExp] -> Binder Kernels (Body Kernels)
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM
                            )
                            ([BinderT
   Kernels
   (State VNameSource)
   (Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody [ExpT Kernels -> BinderT Kernels (State VNameSource) (ExpT Kernels)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpT Kernels
 -> BinderT Kernels (State VNameSource) (ExpT Kernels))
-> ExpT Kernels
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1)])
                      (SubExp, SubExp) -> Binder Kernels (SubExp, SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
a_elem, SubExp
a_loc_ind)
                  [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp]
 -> BinderT
      Kernels
      (State VNameSource)
      (Body (Lore (BinderT Kernels (State VNameSource)))))
-> [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
scatter_a_loc
                [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
loop_a_loc]

              -- copy B from global to shared memory
              VName
b_loc <- SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forLoop SubExp
tk_div_ty [VName
b_loc_init'] ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> BinderT Kernels (State VNameSource) VName)
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
k0 [VName
b_loc_merge] -> do
                VName
loop_b_loc <- SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forLoop SubExp
rx [VName
b_loc_merge] ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> BinderT Kernels (State VNameSource) VName)
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
j0 [VName
b_loc_merge'] -> do
                  [VName]
scatter_b_loc <- String
-> SubExp
-> VName
-> SegLevel
-> (SubExp, SubExp)
-> ((VName, VName) -> Binder Kernels (SubExp, SubExp))
-> Binder Kernels [VName]
segScatter2D
                    String
"B_glb2loc"
                    SubExp
b_loc_sz
                    VName
b_loc_merge'
                    SegLevel
segthd_lvl
                    (SubExp
ty, SubExp
tx)
                    (((VName, VName) -> Binder Kernels (SubExp, SubExp))
 -> Binder Kernels [VName])
-> ((VName, VName) -> Binder Kernels (SubExp, SubExp))
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
thd_y, VName
thd_x) -> do
                      VName
k <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"k" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
thd_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
k0 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
ty)
                      VName
j <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"j" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
thd_x TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
j0 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx)

                      [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
gtid_x] (ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
jjj TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
j)
                      VName
b_row_idx <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"B_row_idx" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
kk TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
k)

                      SubExp
b_elem <-
                        String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"B_elem"
                          (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
  Kernels
  (State VNameSource)
  (Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
                            ( TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$
                                VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_x TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
width_B
                                  TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. if Bool
epilogue
                                    then VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
b_row_idx TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
common_dim
                                    else TPrimExp Bool VName
forall v. TPrimExp Bool v
true
                            )
                            ( do
                                Stm (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm Stm (Lore (BinderT Kernels (State VNameSource)))
Stm Kernels
load_B
                                VName
res <- String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"B_elem" VName
inp_B [VName
b_row_idx]
                                [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
res]
                            )
                            ([BinderT
   Kernels
   (State VNameSource)
   (Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody [Type
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *). MonadBinder m => Type -> m (Exp (Lore m))
eBlank (Type
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> Type
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
map_t2])

                      SubExp
b_loc_ind <-
                        String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"b_loc_ind"
                          (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
  Kernels
  (State VNameSource)
  (Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
                            (TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
k TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
tk)
                            ( TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
j TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
k TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx_rx)
                                BinderT Kernels (State VNameSource) (ExpT Kernels)
-> (ExpT Kernels -> Binder Kernels [SubExp])
-> Binder Kernels [SubExp]
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels [SubExp]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [SubExp]
letTupExp' String
"loc_fi"
                                Binder Kernels [SubExp]
-> ([SubExp] -> Binder Kernels (Body Kernels))
-> Binder Kernels (Body Kernels)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= [SubExp] -> Binder Kernels (Body Kernels)
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM
                            )
                            ([BinderT
   Kernels
   (State VNameSource)
   (Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody [ExpT Kernels -> BinderT Kernels (State VNameSource) (ExpT Kernels)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpT Kernels
 -> BinderT Kernels (State VNameSource) (ExpT Kernels))
-> ExpT Kernels
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1)])
                      (SubExp, SubExp) -> Binder Kernels (SubExp, SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
b_elem, SubExp
b_loc_ind)
                  [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp]
 -> BinderT
      Kernels
      (State VNameSource)
      (Body (Lore (BinderT Kernels (State VNameSource)))))
-> [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
scatter_b_loc
                [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
loop_b_loc]

              -- inner loop updating this thread's accumulator (loop k in mmm_kernels).
              VName
thd_acc <- SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forLoop SubExp
tk [VName
thd_res_merge] ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> BinderT Kernels (State VNameSource) VName)
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
k [VName
acc_merge] ->
                [SubExp] -> Binder Kernels (Body Kernels)
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp] -> Binder Kernels (Body Kernels))
-> Binder Kernels [SubExp] -> Binder Kernels (Body Kernels)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels [SubExp]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [SubExp]
letTupExp' String
"foo"
                  (ExpT Kernels -> Binder Kernels [SubExp])
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels [SubExp]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
  Kernels
  (State VNameSource)
  (Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
                    ( TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$
                        if Bool
epilogue
                          then
                            VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
kk TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
k
                              TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
common_dim
                          else TPrimExp Bool VName
forall v. TPrimExp Bool v
true -- if in prologue, always compute redomap.
                    )
                    ( do
                        [VName]
reg_mem <- String
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
segMap2D String
"reg_mem" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
ty, SubExp
tx) (((VName, VName) -> Binder Kernels [SubExp])
 -> Binder Kernels [VName])
-> ((VName, VName) -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$
                          \(VName
ltid_y, VName
ltid_x) -> do
                            VName
asss_init <- String
-> PrimType
-> [SubExp]
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> PrimType -> [SubExp] -> m VName
scratch String
"asss_init" PrimType
map_t1 [SubExp
ry]
                            VName
bsss_init <- String
-> PrimType
-> [SubExp]
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> PrimType -> [SubExp] -> m VName
scratch String
"bsss_init" PrimType
map_t2 [SubExp
rx]

                            VName
asss <- SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forLoop SubExp
ry [VName
asss_init] ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> BinderT Kernels (State VNameSource) VName)
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
i [VName
asss_merge] -> do
                              VName
a_loc_ind <-
                                String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"a_loc_ind"
                                  (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp
                                    ( VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
k
                                        TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
ry TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tk
                                    )

                              VName
asss <-
                                String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"A_loc_elem" VName
a_loc [VName
a_loc_ind]
                                  BinderT Kernels (State VNameSource) VName
-> (VName -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= String
-> VName
-> [VName]
-> VName
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> VName -> m VName
update String
"asss" VName
asss_merge [VName
i]
                              [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
asss]

                            VName
bsss <- SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forLoop SubExp
rx [VName
bsss_init] ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> BinderT Kernels (State VNameSource) VName)
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
j [VName
bsss_merge] -> do
                              VName
b_loc_ind <-
                                String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"b_loc_ind"
                                  (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp
                                    ( VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
j
                                        TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
k TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx_rx
                                        TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_x TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
rx
                                    )

                              VName
bsss <-
                                String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"B_loc_elem" VName
b_loc [VName
b_loc_ind]
                                  BinderT Kernels (State VNameSource) VName
-> (VName -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= String
-> VName
-> [VName]
-> VName
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> VName -> m VName
update String
"bsss" VName
bsss_merge [VName
j]
                              [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
bsss]
                            [SubExp] -> Binder Kernels [SubExp]
forall (m :: * -> *) a. Monad m => a -> m a
return ([SubExp] -> Binder Kernels [SubExp])
-> [SubExp] -> Binder Kernels [SubExp]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName
asss, VName
bsss]

                        let [VName
asss, VName
bsss] = [VName]
reg_mem

                        -- the actual redomap.
                        [VName]
redomap_res <- String
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
segMap2D String
"redomap_res" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
ty, SubExp
tx) (((VName, VName) -> Binder Kernels [SubExp])
 -> Binder Kernels [VName])
-> ((VName, VName) -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$
                          \(VName
ltid_y, VName
ltid_x) -> do
                            VName
as <- String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"as" VName
asss [VName
ltid_y, VName
ltid_x]
                            VName
bs <- String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"bs" VName
bsss [VName
ltid_y, VName
ltid_x]
                            VName
css_init <- String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"css_init" VName
acc_merge [VName
ltid_y, VName
ltid_x]

                            VName
css <- SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forLoop SubExp
ry [VName
css_init] ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> BinderT Kernels (State VNameSource) VName)
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
i [VName
css_merge] -> do
                              VName
css <- SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forLoop SubExp
rx [VName
css_merge] ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> BinderT Kernels (State VNameSource) VName)
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
j [VName
css_merge'] ->
                                [SubExp] -> Binder Kernels (Body Kernels)
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp] -> Binder Kernels (Body Kernels))
-> Binder Kernels [SubExp] -> Binder Kernels (Body Kernels)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels [SubExp]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [SubExp]
letTupExp' String
"foo"
                                  (ExpT Kernels -> Binder Kernels [SubExp])
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels [SubExp]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
  Kernels
  (State VNameSource)
  (Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
                                    ( TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$
                                        VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
iii TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ SubExp -> TPrimExp Int64 VName
pe64 SubExp
ry TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_y
                                          TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
height_A
                                            TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
jjj TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
j TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ SubExp -> TPrimExp Int64 VName
pe64 SubExp
rx TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_x
                                          TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
width_B
                                    )
                                    ( do
                                        VName
a <- String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"a" VName
as [VName
i]
                                        VName
b <- String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"b" VName
bs [VName
j]
                                        VName
c <- String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"c" VName
css_merge' [VName
i, VName
j]

                                        VName
map_res <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"map_res"
                                        Lambda Kernels
map_lam' <- Lambda Kernels
-> BinderT Kernels (State VNameSource) (Lambda Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda Lambda Kernels
map_lam
                                        Lambda Kernels
red_lam' <- Lambda Kernels
-> BinderT Kernels (State VNameSource) (Lambda Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda Lambda Kernels
red_lam

                                        -- the inputs to map are supposed to be permutted with the
                                        -- inverted permutation, so as to reach the original position;
                                        -- it just so happens that the inverse of [a,b] is [b,a]
                                        let map_inp_reg :: [VName]
map_inp_reg = if [Int]
var_dims [Int] -> [Int] -> Bool
forall a. Eq a => a -> a -> Bool
== [Int
0, Int
1] then [VName
a, VName
b] else [VName
b, VName
a]

                                        Stms (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) ())
-> Stms (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
                                          Lambda Kernels -> [VName] -> [VName] -> Stms Kernels
rebindLambda Lambda Kernels
map_lam' [VName]
map_inp_reg [VName
map_res]
                                            Stms Kernels -> Stms Kernels -> Stms Kernels
forall a. Semigroup a => a -> a -> a
<> Lambda Kernels -> [VName] -> [VName] -> Stms Kernels
rebindLambda Lambda Kernels
red_lam' [VName
c, VName
map_res] [VName
c]

                                        VName
css <- String
-> VName
-> [VName]
-> VName
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> VName -> m VName
update String
"css" VName
css_merge' [VName
i, VName
j] VName
c

                                        [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
css]
                                    )
                                    ([SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
css_merge'])
                              [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
css]
                            [SubExp] -> Binder Kernels [SubExp]
forall (m :: * -> *) a. Monad m => a -> m a
return [VName -> SubExp
Var VName
css]

                        [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp]
 -> BinderT
      Kernels
      (State VNameSource)
      (Body (Lore (BinderT Kernels (State VNameSource)))))
-> [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
redomap_res
                    )
                    ([SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
acc_merge])
              [VName] -> Binder Kernels [VName]
forall (m :: * -> *) a. Monad m => a -> m a
return [VName
thd_acc, VName
a_loc, VName
b_loc]

        -- build prologue.
        VName
full_tiles <-
          String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"full_tiles" (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) VName)
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$
            BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SQuot IntType
Int64 Safety
Unsafe) SubExp
common_dim SubExp
tk
        [VName]
prologue_res_list <-
          SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> Binder Kernels [VName]
forLoop' (VName -> SubExp
Var VName
full_tiles) [VName
cssss, VName
a_loc_init, VName
b_loc_init] ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> Binder Kernels [VName])
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$
            \VName
kk0 [VName
thd_res_merge, VName
a_loc_merge, VName
b_loc_merge] -> do
              [VName]
process_full_tiles <-
                VName -> (VName, VName, VName) -> Bool -> Binder Kernels [VName]
kkLoopBody VName
kk0 (VName
thd_res_merge, VName
a_loc_merge, VName
b_loc_merge) Bool
False

              [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp]
 -> BinderT
      Kernels
      (State VNameSource)
      (Body (Lore (BinderT Kernels (State VNameSource)))))
-> [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
process_full_tiles

        let VName
prologue_res : VName
a_loc_reuse : VName
b_loc_reuse : [VName]
_ = [VName]
prologue_res_list

        -- build epilogue.
        [VName]
epilogue_res_list <- VName -> (VName, VName, VName) -> Bool -> Binder Kernels [VName]
kkLoopBody VName
full_tiles (VName
prologue_res, VName
a_loc_reuse, VName
b_loc_reuse) Bool
True

        let VName
redomap_res : [VName]
_ = [VName]
epilogue_res_list

        -- support for non-empty code2'
        --  segmap (ltid_y < ty, ltid_x < tx) {
        --    for i < ry do
        --      for j < rx do
        --        res = if (iii+ltid_y*ry+i < height_A && jjj+ltid_x*rx+j < width_B)
        --              then code2' else dummy
        --        final_res[i,j] = res
        VName
epilogue_res <-
          if PatElemT Type -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT Type
redomap_orig_res VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
res_nm
            then VName -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
redomap_res -- epilogue_res_list
            else do
              [VName]
rssss_list <- String
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
segMap2D String
"rssss" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
ty, SubExp
tx) (((VName, VName) -> Binder Kernels [SubExp])
 -> Binder Kernels [VName])
-> ((VName, VName) -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
ltid_y, VName
ltid_x) -> do
                VName
rss_init <- String
-> PrimType
-> [SubExp]
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> PrimType -> [SubExp] -> m VName
scratch String
"rss_init" (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
res_tp) [SubExp
ry, SubExp
rx]
                VName
css <- String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"redomap_thd" VName
redomap_res [VName
ltid_y, VName
ltid_x]
                VName
ii <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"ii" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
iii TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
ry)
                VName
jj <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"jj" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
jjj TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_x TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
rx)
                VName
rss <- SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forLoop SubExp
ry [VName
rss_init] ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> BinderT Kernels (State VNameSource) VName)
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
i [VName
rss_merge] -> do
                  VName
rss' <- SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forLoop SubExp
rx [VName
rss_merge] ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> BinderT Kernels (State VNameSource) VName)
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
j [VName
rss_merge'] -> do
                    VName
c <- String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"redomap_elm" VName
css [VName
i, VName
j]
                    Stm Kernels
cpy_stm <- [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT
     Kernels
     (State VNameSource)
     (Stm (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m (Stm (Lore m))
mkLetNamesM [PatElemT Type -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT Type
redomap_orig_res] (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT
      Kernels
      (State VNameSource)
      (Stm (Lore (BinderT Kernels (State VNameSource)))))
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT
     Kernels
     (State VNameSource)
     (Stm (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
c
                    Stm (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm Stm (Lore (BinderT Kernels (State VNameSource)))
Stm Kernels
cpy_stm
                    [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
gtid_y] (ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ii TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i)
                    [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
gtid_x] (ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
jj TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
j)

                    SubExp
res_el <-
                      String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"res_elem"
                        (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
  Kernels
  (State VNameSource)
  (Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
                          ( TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$
                              VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_y TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
height_A
                                TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_x TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
width_B
                          )
                          ( do
                              Stms (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms Stms (Lore (BinderT Kernels (State VNameSource)))
Stms Kernels
code2'
                              [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
res_nm]
                          )
                          ([BinderT
   Kernels
   (State VNameSource)
   (Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody [Type
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *). MonadBinder m => Type -> m (Exp (Lore m))
eBlank Type
res_tp])
                    VName
rss'' <- String
-> VName
-> [VName]
-> SubExp
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> SubExp -> m VName
update' String
"rss" VName
rss_merge' [VName
i, VName
j] SubExp
res_el
                    [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
rss'']
                  [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
rss']
                [SubExp] -> Binder Kernels [SubExp]
forall (m :: * -> *) a. Monad m => a -> m a
return [VName -> SubExp
Var VName
rss]
              let VName
rssss : [VName]
_ = [VName]
rssss_list
              VName -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
rssss

        let regtile_ret_dims :: [(SubExp, SubExp, SubExp)]
regtile_ret_dims =
              ((VName, SubExp) -> (SubExp, SubExp, SubExp))
-> [(VName, SubExp)] -> [(SubExp, SubExp, SubExp)]
forall a b. (a -> b) -> [a] -> [b]
map (\(VName
_, SubExp
sz) -> (SubExp
sz, SubExp
se1, SubExp
se1)) [(VName, SubExp)]
rem_outer_dims
                [(SubExp, SubExp, SubExp)]
-> [(SubExp, SubExp, SubExp)] -> [(SubExp, SubExp, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(SubExp
height_A, SubExp
ty, SubExp
ry), (SubExp
width_B, SubExp
tx, SubExp
rx)]

        -- Add dummy dimensions to tile to reflect the outer dimensions.
        VName
epilogue_res' <-
          if [(VName, SubExp)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(VName, SubExp)]
rem_outer_dims
            then VName -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
epilogue_res
            else do
              Type
epilogue_t <- VName -> BinderT Kernels (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
epilogue_res
              let ([SubExp]
block_dims, [SubExp]
rest_dims) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
2 ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$ Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
epilogue_t
                  ones :: [SubExp]
ones = ((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> (VName, SubExp) -> SubExp
forall a b. a -> b -> a
const (SubExp -> (VName, SubExp) -> SubExp)
-> SubExp -> (VName, SubExp) -> SubExp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) [(VName, SubExp)]
rem_outer_dims
                  new_shape :: [SubExp]
new_shape = [[SubExp]] -> [SubExp]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[SubExp]
ones, [SubExp]
block_dims, [SubExp]
ones, [SubExp]
rest_dims]
              String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"res_reshaped" (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) VName)
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> VName -> BasicOp
Reshape ((SubExp -> DimChange SubExp) -> [SubExp] -> ShapeChange SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimChange SubExp
forall d. d -> DimChange d
DimNew [SubExp]
new_shape) VName
epilogue_res

        [KernelResult] -> Binder Kernels [KernelResult]
forall (m :: * -> *) a. Monad m => a -> m a
return [[(SubExp, SubExp, SubExp)] -> VName -> KernelResult
RegTileReturns [(SubExp, SubExp, SubExp)]
regtile_ret_dims VName
epilogue_res']

      let level' :: SegLevel
level' = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegGroup (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Count SubExp
grid_size) (SubExp -> Count GroupSize SubExp
forall u e. e -> Count u e
Count SubExp
group_size) SegVirt
SegNoVirt
          space' :: SegSpace
space' = VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat ([(VName, SubExp)]
rem_outer_dims [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
gid_y, SubExp
gridDim_y), (VName
gid_x, SubExp
gridDim_x)])
          kbody' :: KernelBody Kernels
kbody' = BodyDec Kernels
-> Stms Kernels -> [KernelResult] -> KernelBody Kernels
forall lore.
BodyDec lore -> Stms lore -> [KernelResult] -> KernelBody lore
KernelBody () Stms Kernels
stms_seggroup [KernelResult]
ret_seggroup
      Stm Kernels -> Binder Kernels (Stm Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stm Kernels -> Binder Kernels (Stm Kernels))
-> Stm Kernels -> Binder Kernels (Stm Kernels)
forall a b. (a -> b) -> a -> b
$ Pattern Kernels
-> StmAux (ExpDec Kernels) -> ExpT Kernels -> Stm Kernels
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let Pattern Kernels
pat StmAux (ExpDec Kernels)
aux (ExpT Kernels -> Stm Kernels) -> ExpT Kernels -> Stm Kernels
forall a b. (a -> b) -> a -> b
$ Op Kernels -> ExpT Kernels
forall lore. Op lore -> ExpT lore
Op (Op Kernels -> ExpT Kernels) -> Op Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels)
forall lore op. SegOp SegLevel lore -> HostOp lore op
SegOp (SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels))
-> SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels)
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [Type]
-> KernelBody Kernels
-> SegOp SegLevel Kernels
forall lvl lore.
lvl -> SegSpace -> [Type] -> KernelBody lore -> SegOp lvl lore
SegMap SegLevel
level' SegSpace
space' [Type]
ts KernelBody Kernels
kbody'
    Maybe (Stms Kernels, Stm Kernels)
-> TileM (Maybe (Stms Kernels, Stm Kernels))
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe (Stms Kernels, Stm Kernels)
 -> TileM (Maybe (Stms Kernels, Stm Kernels)))
-> Maybe (Stms Kernels, Stm Kernels)
-> TileM (Maybe (Stms Kernels, Stm Kernels))
forall a b. (a -> b) -> a -> b
$ (Stms Kernels, Stm Kernels) -> Maybe (Stms Kernels, Stm Kernels)
forall a. a -> Maybe a
Just (Stms Kernels
host_stms, Stm Kernels
new_kernel)
mmBlkRegTiling Stm Kernels
_ = Maybe (Stms Kernels, Stm Kernels)
-> TileM (Maybe (Stms Kernels, Stm Kernels))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (Stms Kernels, Stm Kernels)
forall a. Maybe a
Nothing

-- | Build, without binding it, the expression dividing @x@ by @y@ with
--   'SDivUp' (signed 'Int64' division rounding up), marked 'Unsafe'.
ceilDiv :: MonadBinder m => SubExp -> SubExp -> m (Exp (Lore m))
ceilDiv x y = pure $ BasicOp $ BinOp div_up x y
  where
    div_up = SDivUp Int64 Unsafe

-- | Bind a fresh 'Scratch' array with element type @elem_t@ and the
--   dimensions @dims@, using @desc@ as the base of the bound name.
scratch :: MonadBinder m => String -> PrimType -> [SubExp] -> m VName
scratch desc elem_t dims =
  letExp desc (BasicOp (Scratch elem_t dims))

-- | Index an array with the indices given in @outer_indices@; any inner
--   dimensions of @arr@ not covered by @outer_indices@ are sliced in their
--   entirety (offset 0, full length, stride 1).  The result is bound to a
--   fresh name based on @se_desc@.
index :: MonadBinder m => String -> VName -> [VName] -> m VName
index se_desc arr outer_indices = do
  arr_dims <- arrayDims <$> lookupType arr
  let fixed = [DimFix (Var v) | v <- outer_indices]
      whole d = DimSlice (intConst Int64 0) d (intConst Int64 1)
      rest = map whole $ drop (length outer_indices) arr_dims
  letExp se_desc $ BasicOp $ Index arr $ fixed ++ rest

-- | Variant of 'update'' where the new element is given as a variable
--   name rather than an arbitrary 'SubExp'.
update :: MonadBinder m => String -> VName -> [VName] -> VName -> m VName
update se_desc arr indices = update' se_desc arr indices . Var

-- | Bind the result of an 'Update' of @arr@.  The position written is
--   addressed by fixing one leading dimension per element of @indices@,
--   and @new_elem@ is stored there.  The bound name is based on @se_desc@.
update' :: MonadBinder m => String -> VName -> [VName] -> SubExp -> m VName
update' se_desc arr indices new_elem =
  letExp se_desc $ BasicOp $ Update arr slice new_elem
  where
    slice = [DimFix (Var i) | i <- indices]

-- | Emit a sequential @for@ loop and bind its results.
--
--   The loop runs @i_bound@ iterations, with a fresh 'Int64' induction
--   variable.  Each name in @merge@ is the initial value of one
--   loop-carried (merge) parameter.  That parameter is given the type of
--   its initial value, declared 'Unique'.
--
--   The body builder is called with the induction variable and the
--   merge-parameter names, inside the scope of both.  The names bound to
--   the loop's results are returned.
forLoop' ::
  SubExp -> -- number of iterations (loop bound)
  [VName] -> -- initial values of the loop-carried variables
  ( VName -> -- induction variable
    [VName] -> -- merge parameters
    Binder Kernels (Body Kernels) -- builds the loop body
  ) ->
  Binder Kernels [VName]
forLoop' i_bound merge body = do
  i <- newVName "i" -- could give this as arg to the function
  let loop_form = ForLoop i Int64 i_bound []
  merge_params <- mapM (newParam "merge" . (`toDecl` Unique)) =<< mapM lookupType merge
  loop_body <-
    runBodyBinder $
      inScopeOf loop_form $
        localScope (scopeOfFParams merge_params) $
          body i (map paramName merge_params)
  letTupExp "loop" $
    DoLoop [] [(p, Var v) | (p, v) <- zip merge_params merge] loop_form loop_body

-- | Like 'forLoop'', but returns only the name bound to the first loop
--   result.  This is intended for loops with a single loop-carried
--   variable.
--
--   If @merge@ is empty, the loop has no result.  Forcing the returned
--   name then raises a descriptive 'error', instead of the anonymous
--   failure of a partial 'head'.
forLoop ::
  SubExp ->
  [VName] ->
  (VName -> [VName] -> Binder Kernels (Body Kernels)) ->
  Binder Kernels VName
forLoop i_bound merge body =
  -- 'fmap' keeps the original laziness: the empty case only fails
  -- when the result name is actually demanded.
  firstResult <$> forLoop' i_bound merge body
  where
    firstResult :: [VName] -> VName
    firstResult (res : _) = res
    firstResult [] =
      error "forLoop: loop has no merge variables, so there is no result to return"

-- | Inline a lambda as a flat sequence of statements.
--
--   Given a lambda @lam@, a list @new_params@ of variables to pass to the
--   lambda as its arguments, and a list @res_names@ of names to bind the
--   lambda's results to, this creates statements corresponding to:
--
--     1. binding each lambda parameter to the matching element of
--        @new_params@;
--     2. the lambda body;
--     3. binding each element of @res_names@ to the matching lambda
--        result.
--
--   Pairing is positional, as with 'zipWith': surplus parameters,
--   arguments, result names, or results are ignored.
rebindLambda ::
  Lambda Kernels ->
  [VName] ->
  [VName] ->
  Stms Kernels
rebindLambda lam new_params res_names =
  stmsFromList (zipWith bindParam (lambdaParams lam) new_params)
    <> bodyStms lam_body
    <> stmsFromList
      (zipWith3 bindResult res_names (lambdaReturnType lam) (bodyResult lam_body))
  where
    lam_body = lambdaBody lam

    bindParam :: Param Type -> VName -> Stm Kernels
    bindParam param new_param =
      mkLet [] [Ident (paramName param) (paramDec param)] $
        BasicOp $ SubExp $ Var new_param

    -- Each result name takes the type of its own lambda result, so that
    -- lambdas returning several values of different types are rebound
    -- with correct types.
    bindResult :: VName -> Type -> SubExp -> Stm Kernels
    bindResult res_name res_t lam_res =
      mkLet [] [Ident res_name res_t] $ BasicOp $ SubExp lam_res

-- | Tries to identify the following pattern:
--   code followed by some Screma followed by more code.
--
--   The statements are split at the FIRST 'Screma' statement.  The result
--   is the statements before it, the 'Screma' statement itself (if any),
--   and every statement after it (which may include further Scremas).
--   When there is no 'Screma', all statements end up in the first
--   component and the last component is empty.
matchCodeStreamCode ::
  Stms Kernels ->
  (Stms Kernels, Maybe (Stm Kernels), Stms Kernels)
matchCodeStreamCode kstms =
  -- A single linear 'break' replaces a fold that appended one element at
  -- a time with '++', which was quadratic in the number of statements.
  case break isScrema (stmsToList kstms) of
    (code1, screma : code2) -> (stmsFromList code1, Just screma, stmsFromList code2)
    (code1, []) -> (stmsFromList code1, Nothing, stmsFromList [])
  where
    isScrema :: Stm Kernels -> Bool
    isScrema (Let _ _ (Op (OtherOp Screma {}))) = True
    isScrema _ = False

-- | Checks that all streamed arrays are variant to exactly one of
--   the two innermost parallel dimensions, and conversely, for
--   each of the two innermost parallel dimensions, there is at
--   least one streamed array variant to it. The result is the
--   number of the only variant parallel dimension for each array:
--   0 for the second-innermost dimension, 1 for the innermost one.
--
--   Yields 'Nothing' in the following cases:
--
--   * the kernel space has fewer than two dimensions;
--   * either of the two innermost dimensions is in @branch_variant@;
--   * some array is variant to both or neither of them;
--   * one of the two dimensions has no variant array.
isInvarTo1of2InnerDims ::
  Names ->
  SegSpace ->
  VarianceTable ->
  [VName] ->
  Maybe [Int]
isInvarTo1of2InnerDims branch_variant kspace variance arrs = do
  dims <- mapM innerDimOf arrs
  if 0 `elem` dims && 1 `elem` dims then Just dims else Nothing
  where
    innerDimOf :: VName -> Maybe Int
    innerDimOf arr
      | (j, _) : (i, _) : _ <- reverse (unSegSpace kspace),
        not (j `nameIn` branch_variant || i `nameIn` branch_variant) =
        let arr_variance = M.findWithDefault mempty arr variance
         in case (i `nameIn` arr_variance, j `nameIn` arr_variance) of
              (True, False) -> Just 0
              (False, True) -> Just 1
              _ -> Nothing
      | otherwise = Nothing

-- | One step of a fold over statements that separates indexing
--   "indirection" statements from the rest.
--
--   The accumulator holds the statements kept so far, together with a
--   table of indirection statements keyed by the name each one binds.
--   'Nothing' means an earlier step already failed, and it is propagated
--   unchanged.  For each statement:
--
--   * If it is an 'Index' with a single value result whose name is in
--     @arrs@, it is recorded in the table under that name.
--   * Otherwise, it is appended to the kept statements, provided none of
--     its value results is in @res_red_var@.
--   * If neither applies, the fold fails with 'Nothing'.
processIndirections ::
  Names -> -- input arrays to redomap
  Names -> -- variables on which the redomap's result depends
  Maybe (Stms Kernels, M.Map VName (Stm Kernels)) ->
  Stm Kernels ->
  Maybe (Stms Kernels, M.Map VName (Stm Kernels))
processIndirections arrs res_red_var acc stm@(Let patt _ e) = do
  (ss, tab) <- acc
  let pes = patternValueElements patt
  case (e, pes) of
    (BasicOp Index {}, [pe])
      | patElemName pe `nameIn` arrs ->
        Just (ss, M.insert (patElemName pe) stm tab)
    _
      | any ((`nameIn` res_red_var) . patElemName) pes -> Nothing
      | otherwise -> Just (ss Seq.|> stm, tab)

-- | Small 'Int64' constants as 'SubExp's, shared by the tiling code.
se0, se1, se2, se4, se8 :: SubExp
se0 = intConst Int64 0
se1 = intConst Int64 1
se2 = intConst Int64 2
se4 = intConst Int64 4
se8 = intConst Int64 8

-- | Choose the (tile, register-tile) sizes for a parallel dimension of
--   length @len_dim@.
--
--   For the statically known lengths 8, 16 and 32, the sizes are fixed:
--   the tile is 8, and the register tile is 1, 2 and 4 respectively.
--
--   Otherwise two tunable sizes are bound, the tile first:
--
--   * a 'SizeTile' named @t_name@, bound under the description @t_str@;
--   * a 'SizeRegTile' named @r_name@, bound under the description @r_str@.
getParTiles :: (String, String) -> (Name, Name) -> SubExp -> Binder Kernels (SubExp, SubExp)
getParTiles (t_str, r_str) (t_name, r_name) len_dim =
  maybe tunable return (fixedTiles len_dim)
  where
    fixedTiles :: SubExp -> Maybe (SubExp, SubExp)
    fixedTiles (Constant (IntValue (Int64Value 8))) = Just (se8, se1)
    fixedTiles (Constant (IntValue (Int64Value 16))) = Just (se8, se2)
    fixedTiles (Constant (IntValue (Int64Value 32))) = Just (se8, se4)
    fixedTiles _ = Nothing

    tunable :: Binder Kernels (SubExp, SubExp)
    tunable =
      (,) <$> sizeOf t_str t_name SizeTile <*> sizeOf r_str r_name SizeRegTile

    sizeOf :: String -> Name -> SizeClass -> Binder Kernels SubExp
    sizeOf desc size_name size_class =
      letSubExp desc $ Op $ SizeOp $ GetSize size_name size_class

-- | Choose the sequential tile size and bind it under the description
--   @tk_str@.
--
--   If both @tx@ and @ty@ are 'Int64' constants, the tile is the smaller
--   of the two.  It is further capped by @len_dim@ when that is also a
--   constant.  Otherwise a tunable 'SizeTile' named @tk_name@ is used.
getSeqTile :: String -> Name -> SubExp -> SubExp -> SubExp -> Binder Kernels SubExp
getSeqTile tk_str tk_name len_dim ty tx
  | Just v_x <- constI64 tx,
    Just v_y <- constI64 ty =
    bindTile $ BasicOp $ SubExp $ constant $ capBy (constI64 len_dim) (min v_x v_y)
  | otherwise =
    bindTile $ Op $ SizeOp $ GetSize tk_name SizeTile
  where
    constI64 :: SubExp -> Maybe Int64
    constI64 (Constant (IntValue (Int64Value v))) = Just v
    constI64 _ = Nothing

    -- Cap the tile by the dimension length when that is known statically.
    capBy :: Maybe Int64 -> Int64 -> Int64
    capBy = maybe id min

    bindTile :: Exp Kernels -> Binder Kernels SubExp
    bindTile = letSubExp tk_str

----------------------------------------------------------------------------------------------
--- 3D Tiling (RegTiling for the outermost dimension & Block tiling for the innermost two) ---
----------------------------------------------------------------------------------------------

-- | Total register budget for 3D register tiling. 'doRegTiling3D'
--   divides it (with 'quot') by the larger of the number of Redomap
--   results and the number of kernel results to obtain the per-result
--   register tile.
maxRegTile :: Int64
maxRegTile = 30

-- | Turns a register-tile size into a constant 'SubExp', so that it can be
--   used wherever a size expression is expected.
mkRegTileSe :: Int64 -> SubExp
mkRegTileSe = constant

-- | @variantToDim variance gid_outer nm@ holds when @nm@ is @gid_outer@
--   itself, or when the variance table records @gid_outer@ among the names
--   that @nm@ depends on. A name missing from the table is treated as
--   depending on nothing.
variantToDim :: VarianceTable -> VName -> VName -> Bool
variantToDim variance gid_outer nm
  | gid_outer == nm = True
  | otherwise = maybe False (nameIn gid_outer) (M.lookup nm variance)

-- | Checks two things about the three innermost parallel dimensions of
--   @kspace@:
--
--   * each streamed array in @arrs@ is variant to exactly one of them;
--   * conversely, each of them has at least one streamed array variant
--     to it.
--
--   The check fails outright (returns 'Nothing') in two cases: any of the
--   three innermost dimensions is in @branch_variant@, or @kspace@ has
--   fewer than three dimensions.
--
--   On success, the result lists, for each array in order, the position of
--   its only variant dimension among the three innermost ones: @0@ for the
--   outermost of the three, @1@ for the middle one and @2@ for the
--   innermost.
isInvarTo2of3InnerDims ::
  Names ->
  SegSpace ->
  VarianceTable ->
  [VName] ->
  Maybe [Int]
isInvarTo2of3InnerDims branch_variant kspace variance arrs = do
  -- 'mapM' in 'Maybe' gives 'Nothing' as soon as one array does not qualify.
  dims <- mapM soleVariantDim arrs
  -- Every one of the three innermost dimensions must be covered.
  if all (`elem` dims) [0, 1, 2]
    then Just dims
    else Nothing
  where
    -- The segment space, innermost dimension first.
    innermost_first = reverse $ unSegSpace kspace

    -- Position (0 = outermost of the three, 2 = innermost) of the single
    -- inner dimension that @arr@ is variant to, if there is exactly one.
    soleVariantDim :: VName -> Maybe Int
    soleVariantDim arr =
      case innermost_first of
        (inner, _) : (middle, _) : (outer, _) : _
          | any (`nameIn` branch_variant) [outer, middle, inner] -> Nothing
          | otherwise ->
              let deps = M.findWithDefault mempty arr variance
                  variant_dims =
                    [pos | (pos, gid) <- zip [0 ..] [outer, middle, inner], gid `nameIn` deps]
               in case variant_dims of
                    [pos] -> Just pos
                    _ -> Nothing
        _ -> Nothing

-- | Expects a kernel statement as argument.
--   CONDITIONS for 3D tiling optimization to fire are:
--     1. a) The kernel body can be broken into
--              scalar-code-1 ++ [Redomap stmt] ++ scalar-code-2.
--        b) The kernel has a per-thread result, and obviously
--              the result is variant to the 3rd dimension
--              (counted from innermost to outermost).
--     2. For the Redomap:
--          a) the streamed arrays are one-dimensional;
--          b) each of the array arguments of the Redomap is variant
--              to exactly one of the three innermost parallel dimensions
--              of the kernel. This condition can be relaxed by interchanging
--              kernel dimensions whenever possible.
--     3. For scalar-code-1:
--          a) each of the statements is either a slice that produces one of
--             the streamed arrays, or a statement on which the Redomap does
--             not depend (such statements are prepended to scalar-code-2).
--
-- mmBlkRegTiling :: Stm Kernels -> TileM (Maybe (Stms Kernels, Stm Kernels))
-- mmBlkRegTiling (Let pat aux (Op (SegOp (SegMap SegThread{} seg_space ts old_kbody))))
doRegTiling3D :: Stm Kernels -> TileM (Maybe (Stms Kernels, Stm Kernels))
doRegTiling3D :: Stm Kernels -> TileM (Maybe (Stms Kernels, Stm Kernels))
doRegTiling3D (Let Pattern Kernels
pat StmAux (ExpDec Kernels)
aux (Op (SegOp old_kernel)))
  | SegMap SegThread {} SegSpace
space [Type]
kertp (KernelBody () Stms Kernels
kstms [KernelResult]
kres) <- SegOp SegLevel Kernels
old_kernel,
    -- build the variance table, that records, for
    -- each variable name, the variables it depends on
    Map VName Names
initial_variance <- (NameInfo Any -> Names)
-> Map VName (NameInfo Any) -> Map VName Names
forall a b k. (a -> b) -> Map k a -> Map k b
M.map NameInfo Any -> Names
forall a. Monoid a => a
mempty (Map VName (NameInfo Any) -> Map VName Names)
-> Map VName (NameInfo Any) -> Map VName Names
forall a b. (a -> b) -> a -> b
$ SegSpace -> Map VName (NameInfo Any)
forall lore. SegSpace -> Scope lore
scopeOfSegSpace SegSpace
space,
    Map VName Names
variance <- Map VName Names -> Stms Kernels -> Map VName Names
varianceInStms Map VName Names
initial_variance Stms Kernels
kstms,
    -- we get the global-thread id for the two inner dimensions,
    --   as we are probably going to use it in code generation
    (VName
gtid_x, SubExp
d_Kx) : (VName
gtid_y, SubExp
d_Ky) : (VName
gtid_z, SubExp
d_M) : [(VName, SubExp)]
rem_outer_dims_rev <- [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space,
    [(VName, SubExp)]
rem_outer_dims <- [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse [(VName, SubExp)]
rem_outer_dims_rev,
    -- check that the code fits the pattern having:
    -- some `code1`, followed by one Screma SOAC, followed by some `code2`
    (Stms Kernels
code1, Just Stm Kernels
screma_stmt, Stms Kernels
code2) <- Stms Kernels -> (Stms Kernels, Maybe (Stm Kernels), Stms Kernels)
matchCodeStreamCode Stms Kernels
kstms,
    Let Pattern Kernels
pat_redomap StmAux (ExpDec Kernels)
_ (Op Op Kernels
_) <- Stm Kernels
screma_stmt,
    -- checks that the Screma SOAC is actually a redomap and normalize it
    Just (SubExp
common_dim, [VName]
inp_soac_arrs, (Commutativity
_, Lambda Kernels
red_lam, [SubExp]
red_nes, Lambda Kernels
map_lam)) <- Stm Kernels
-> Maybe
     (SubExp, [VName],
      (Commutativity, Lambda Kernels, [SubExp], Lambda Kernels))
isTileableRedomap Stm Kernels
screma_stmt,
    Bool -> Bool
not ([SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [SubExp]
red_nes),
    -- assuming we have a budget of maxRegTile registers, we distribute
    -- that budget across the result of redomap and the kernel result
    Int
num_res <- Int -> Int -> Int
forall a. Ord a => a -> a -> a
max ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) ([KernelResult] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
kres),
    Int64
reg_tile <- Int64
maxRegTile Int64 -> Int64 -> Int64
forall a. Integral a => a -> a -> a
`quot` Int -> Int64
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
num_res,
    SubExp
reg_tile_se <- Int64 -> SubExp
mkRegTileSe Int64
reg_tile,
    -- check that the element-type of the map and reduce are scalars:
    (Param Type -> Bool) -> [Param Type] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> (Param Type -> Type) -> Param Type -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> Type
forall dec. Param dec -> dec
paramDec) ([Param Type] -> Bool) -> [Param Type] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda Kernels -> [LParam Kernels]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda Kernels
map_lam,
    [Type]
red_res_tps <- (Param Type -> Type) -> [Param Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Type
forall dec. Param dec -> dec
paramDec ([Param Type] -> [Type]) -> [Param Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Int -> [Param Type] -> [Param Type]
forall a. Int -> [a] -> [a]
take ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) ([Param Type] -> [Param Type]) -> [Param Type] -> [Param Type]
forall a b. (a -> b) -> a -> b
$ Lambda Kernels -> [LParam Kernels]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda Kernels
red_lam,
    (Type -> Bool) -> [Type] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType [Type]
red_res_tps,
    -- checks that the input arrays to redomap are variant to
    -- exactly one of the two innermost dimensions of the kernel
    Just [Int]
_ <- Names -> SegSpace -> Map VName Names -> [VName] -> Maybe [Int]
isInvarTo2of3InnerDims Names
forall a. Monoid a => a
mempty SegSpace
space Map VName Names
variance [VName]
inp_soac_arrs,
    -- get the free variables on which the result of redomap depends on
    [PatElemT Type]
redomap_orig_res <- PatternT Type -> [PatElemT Type]
forall dec. PatternT dec -> [PatElemT dec]
patternValueElements PatternT Type
Pattern Kernels
pat_redomap,
    Names
res_red_var <- -- variance of the reduce result
      [Names] -> Names
forall a. Monoid a => [a] -> a
mconcat ([Names] -> Names) -> [Names] -> Names
forall a b. (a -> b) -> a -> b
$ (PatElemT Type -> Maybe Names) -> [PatElemT Type] -> [Names]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe ((VName -> Map VName Names -> Maybe Names
forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` Map VName Names
variance) (VName -> Maybe Names)
-> (PatElemT Type -> VName) -> PatElemT Type -> Maybe Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElemT Type -> VName
forall dec. PatElemT dec -> VName
patElemName) [PatElemT Type]
redomap_orig_res,
    Names
forall a. Monoid a => a
mempty Names -> Names -> Bool
forall a. Eq a => a -> a -> Bool
/= Names
res_red_var,
    -- we furthermore check that code1 is only formed by
    -- 1. statements that slice some globally-declared arrays
    --    to produce the input for the redomap, and
    -- 2. potentially some statements on which the redomap
    --    is independent; these are recorded in `code2''`
    Just (Stms Kernels
code2'', Map VName (Stm Kernels)
arr_tab0) <-
      (Maybe (Stms Kernels, Map VName (Stm Kernels))
 -> Stm Kernels -> Maybe (Stms Kernels, Map VName (Stm Kernels)))
-> Maybe (Stms Kernels, Map VName (Stm Kernels))
-> Stms Kernels
-> Maybe (Stms Kernels, Map VName (Stm Kernels))
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
        (Names
-> Names
-> Maybe (Stms Kernels, Map VName (Stm Kernels))
-> Stm Kernels
-> Maybe (Stms Kernels, Map VName (Stm Kernels))
processIndirections ([VName] -> Names
namesFromList [VName]
inp_soac_arrs) Names
res_red_var)
        ((Stms Kernels, Map VName (Stm Kernels))
-> Maybe (Stms Kernels, Map VName (Stm Kernels))
forall a. a -> Maybe a
Just (Stms Kernels
forall a. Seq a
Seq.empty, Map VName (Stm Kernels)
forall k a. Map k a
M.empty))
        Stms Kernels
code1,
    -- check that code1 contains exacly one slice for each of the input array to redomap
    [Stm Kernels]
tmp_stms <- (VName -> Maybe (Stm Kernels)) -> [VName] -> [Stm Kernels]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (VName -> Map VName (Stm Kernels) -> Maybe (Stm Kernels)
forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` Map VName (Stm Kernels)
arr_tab0) [VName]
inp_soac_arrs,
    [Stm Kernels] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Stm Kernels]
tmp_stms Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
inp_soac_arrs,
    -- code1' <- stmsFromList $ stmsToList code1 \\ stmsToList code2'',
    Stms Kernels
code2' <- Stms Kernels
code2'' Stms Kernels -> Stms Kernels -> Stms Kernels
forall a. Semigroup a => a -> a -> a
<> Stms Kernels
code2,
    -- we assume the kernel results are variant to the thrid-outer parallel dimension
    -- (for sanity sake, they should be)
    [VName]
ker_res_nms <- (KernelResult -> Maybe VName) -> [KernelResult] -> [VName]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe KernelResult -> Maybe VName
getResNm [KernelResult]
kres,
    [VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
ker_res_nms Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [KernelResult] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
kres,
    Pattern [] [PatElemT (LetDec Kernels)]
_ <- Pattern Kernels
pat,
    (Type -> Bool) -> [Type] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType [Type]
kertp,
    (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Map VName Names -> VName -> VName -> Bool
variantToDim Map VName Names
variance VName
gtid_z) [VName]
ker_res_nms = do
    -- HERE STARTS THE IMPLEMENTATION:
    (Stm Kernels
new_kernel, Stms Kernels
host_stms) <- Binder Kernels (Stm Kernels)
-> ReaderT
     (Scope Kernels) (State VNameSource) (Stm Kernels, Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore a -> m (a, Stms lore)
runBinder (Binder Kernels (Stm Kernels)
 -> ReaderT
      (Scope Kernels) (State VNameSource) (Stm Kernels, Stms Kernels))
-> Binder Kernels (Stm Kernels)
-> ReaderT
     (Scope Kernels) (State VNameSource) (Stm Kernels, Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
      -- host code
      -- process the z-variant arrays that need transposition;
      -- these "manifest" statements should come before the kernel
      (Map VName (Stm Kernels)
tab_inn, Map VName (PrimType, Stm Kernels)
tab_out) <-
        ((Map VName (Stm Kernels), Map VName (PrimType, Stm Kernels))
 -> (VName, Stm Kernels)
 -> BinderT
      Kernels
      (State VNameSource)
      (Map VName (Stm Kernels), Map VName (PrimType, Stm Kernels)))
-> (Map VName (Stm Kernels), Map VName (PrimType, Stm Kernels))
-> [(VName, Stm Kernels)]
-> BinderT
     Kernels
     (State VNameSource)
     (Map VName (Stm Kernels), Map VName (PrimType, Stm Kernels))
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM
          (Map VName Names
-> (VName, SubExp)
-> (Map VName (Stm Kernels), Map VName (PrimType, Stm Kernels))
-> (VName, Stm Kernels)
-> BinderT
     Kernels
     (State VNameSource)
     (Map VName (Stm Kernels), Map VName (PrimType, Stm Kernels))
insertTranspose Map VName Names
variance (VName
gtid_z, SubExp
d_M))
          (Map VName (Stm Kernels)
forall k a. Map k a
M.empty, Map VName (PrimType, Stm Kernels)
forall k a. Map k a
M.empty)
          ([(VName, Stm Kernels)]
 -> BinderT
      Kernels
      (State VNameSource)
      (Map VName (Stm Kernels), Map VName (PrimType, Stm Kernels)))
-> [(VName, Stm Kernels)]
-> BinderT
     Kernels
     (State VNameSource)
     (Map VName (Stm Kernels), Map VName (PrimType, Stm Kernels))
forall a b. (a -> b) -> a -> b
$ Map VName (Stm Kernels) -> [(VName, Stm Kernels)]
forall k a. Map k a -> [(k, a)]
M.toList Map VName (Stm Kernels)
arr_tab0

      Name
tx_name <- String -> Name
nameFromString (String -> Name) -> (VName -> String) -> VName -> Name
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> String
forall a. Pretty a => a -> String
pretty (VName -> Name)
-> BinderT Kernels (State VNameSource) VName
-> BinderT Kernels (State VNameSource) Name
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"Tx"
      Name
ty_name <- String -> Name
nameFromString (String -> Name) -> (VName -> String) -> VName -> Name
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> String
forall a. Pretty a => a -> String
pretty (VName -> Name)
-> BinderT Kernels (State VNameSource) VName
-> BinderT Kernels (State VNameSource) Name
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"Ty"

      SubExp
tx0 <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"Tx" (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> Binder Kernels SubExp)
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall a b. (a -> b) -> a -> b
$ Op Kernels -> ExpT Kernels
forall lore. Op lore -> ExpT lore
Op (Op Kernels -> ExpT Kernels) -> Op Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ SizeOp -> HostOp Kernels (SOAC Kernels)
forall lore op. SizeOp -> HostOp lore op
SizeOp (SizeOp -> HostOp Kernels (SOAC Kernels))
-> SizeOp -> HostOp Kernels (SOAC Kernels)
forall a b. (a -> b) -> a -> b
$ Name -> SizeClass -> SizeOp
GetSize Name
tx_name SizeClass
SizeTile
      SubExp
ty0 <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"Ty" (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> Binder Kernels SubExp)
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall a b. (a -> b) -> a -> b
$ Op Kernels -> ExpT Kernels
forall lore. Op lore -> ExpT lore
Op (Op Kernels -> ExpT Kernels) -> Op Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ SizeOp -> HostOp Kernels (SOAC Kernels)
forall lore op. SizeOp -> HostOp lore op
SizeOp (SizeOp -> HostOp Kernels (SOAC Kernels))
-> SizeOp -> HostOp Kernels (SOAC Kernels)
forall a b. (a -> b) -> a -> b
$ Name -> SizeClass -> SizeOp
GetSize Name
ty_name SizeClass
SizeTile
      SubExp
ty <- String -> SubExp -> SubExp -> Binder Kernels SubExp
limitTile String
"Ty" SubExp
ty0 SubExp
d_Ky
      SubExp
tx <- String -> SubExp -> SubExp -> Binder Kernels SubExp
limitTile String
"Tx" SubExp
tx0 SubExp
d_Kx
      let rz :: SubExp
rz = SubExp
reg_tile_se

      SubExp
gridDim_x <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"gridDim_x" (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> m (Exp (Lore m))
ceilDiv SubExp
d_Kx SubExp
tx
      SubExp
gridDim_y <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"gridDim_y" (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> m (Exp (Lore m))
ceilDiv SubExp
d_Ky SubExp
ty
      SubExp
gridDim_z <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"gridDim_z" (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> m (Exp (Lore m))
ceilDiv SubExp
d_M SubExp
rz
      let gridxyz_pexp :: TPrimExp Int64 VName
gridxyz_pexp = SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_z TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_x
      let grid_pexp :: TPrimExp Int64 VName
grid_pexp = [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
gridxyz_pexp TPrimExp Int64 VName
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. a -> [a] -> [a]
: ((VName, SubExp) -> TPrimExp Int64 VName)
-> [(VName, SubExp)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName)
-> ((VName, SubExp) -> SubExp)
-> (VName, SubExp)
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) [(VName, SubExp)]
rem_outer_dims_rev
      SubExp
grid_size <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"grid_size_tile3d" (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp TPrimExp Int64 VName
grid_pexp
      SubExp
group_size <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"group_size_tile3d" (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
ty TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx)
      let segthd_lvl :: SegLevel
segthd_lvl = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Count SubExp
grid_size) (SubExp -> Count GroupSize SubExp
forall u e. e -> Count u e
Count SubExp
group_size) SegVirt
SegNoVirtFull

      SubExp
count_shmem <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"count_shmem" (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> m (Exp (Lore m))
ceilDiv SubExp
rz SubExp
group_size

      VName
gid_x <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"gid_x"
      VName
gid_y <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"gid_y"
      VName
gid_z <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"gid_z"
      VName
gid_flat <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"gid_flat"

      ---- in this binder: outer seggroup ----
      ([KernelResult]
ret_seggroup, Stms Kernels
stms_seggroup) <- Binder Kernels [KernelResult]
-> BinderT
     Kernels (State VNameSource) ([KernelResult], Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore a -> m (a, Stms lore)
runBinder (Binder Kernels [KernelResult]
 -> BinderT
      Kernels (State VNameSource) ([KernelResult], Stms Kernels))
-> Binder Kernels [KernelResult]
-> BinderT
     Kernels (State VNameSource) ([KernelResult], Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
        VName
ii <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"ii" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_z TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
rz)
        VName
jj1 <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"jj1" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
ty)
        VName
jj2 <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"jj2" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_x TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx)

        -- initialize the register arrays corresponding to the result of redomap;
        [VName]
reg_arr_nms <- String
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
segMap2D String
"res" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
ty, SubExp
tx) (((VName, VName) -> Binder Kernels [SubExp])
 -> Binder Kernels [VName])
-> ((VName, VName) -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \(VName, VName)
_ ->
          [(SubExp, Type)]
-> ((SubExp, Type) -> Binder Kernels SubExp)
-> Binder Kernels [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SubExp] -> [Type] -> [(SubExp, Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
red_nes [Type]
red_res_tps) (((SubExp, Type) -> Binder Kernels SubExp)
 -> Binder Kernels [SubExp])
-> ((SubExp, Type) -> Binder Kernels SubExp)
-> Binder Kernels [SubExp]
forall a b. (a -> b) -> a -> b
$ \(SubExp
red_ne, Type
red_t) -> do
            VName
css_init <- String
-> PrimType
-> [SubExp]
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> PrimType -> [SubExp] -> m VName
scratch String
"res_init" (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
red_t) [SubExp
rz]
            VName
css <- SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forLoop SubExp
rz [VName
css_init] ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> BinderT Kernels (State VNameSource) VName)
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
i [VName
css_merge] -> do
              VName
css' <- String
-> VName
-> [VName]
-> SubExp
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> SubExp -> m VName
update' String
"css" VName
css_merge [VName
i] SubExp
red_ne
              [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
css']
            SubExp -> Binder Kernels SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp -> Binder Kernels SubExp)
-> SubExp -> Binder Kernels SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
css

        -- scratch the shared-memory arrays corresponding to the arrays that are
        --   input to the redomap and are invariant to the outermost parallel dimension.
        [VName]
loc_arr_nms <- [(VName, (PrimType, Stm Kernels))]
-> ((VName, (PrimType, Stm Kernels))
    -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (Map VName (PrimType, Stm Kernels)
-> [(VName, (PrimType, Stm Kernels))]
forall k a. Map k a -> [(k, a)]
M.toList Map VName (PrimType, Stm Kernels)
tab_out) (((VName, (PrimType, Stm Kernels))
  -> BinderT Kernels (State VNameSource) VName)
 -> Binder Kernels [VName])
-> ((VName, (PrimType, Stm Kernels))
    -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
nm, (PrimType
ptp, Stm Kernels
_)) ->
          String
-> PrimType
-> [SubExp]
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> PrimType -> [SubExp] -> m VName
scratch (VName -> String
baseString VName
nm String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_loc") PrimType
ptp [SubExp
rz]

        [VName]
prologue_res_list <-
          SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> Binder Kernels [VName]
forLoop' SubExp
common_dim ([VName]
reg_arr_nms [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
loc_arr_nms) ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> Binder Kernels [VName])
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$
            \VName
q [VName]
var_nms -> do
              let reg_arr_merge_nms :: [VName]
reg_arr_merge_nms = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
take ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) [VName]
var_nms
              let loc_arr_merge_nms :: [VName]
loc_arr_merge_nms = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) [VName]
var_nms

              -- collective copy from global to shared memory
              [VName]
loc_arr_nms' <-
                SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> Binder Kernels [VName]
forLoop' SubExp
count_shmem [VName]
loc_arr_merge_nms ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> Binder Kernels [VName])
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \VName
tt [VName]
loc_arr_merge2_nms -> do
                  [VName]
loc_arr_merge2_nms' <-
                    [(VName, (VName, (PrimType, Stm Kernels)))]
-> ((VName, (VName, (PrimType, Stm Kernels)))
    -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName]
-> [(VName, (PrimType, Stm Kernels))]
-> [(VName, (VName, (PrimType, Stm Kernels)))]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
loc_arr_merge2_nms (Map VName (PrimType, Stm Kernels)
-> [(VName, (PrimType, Stm Kernels))]
forall k a. Map k a -> [(k, a)]
M.toList Map VName (PrimType, Stm Kernels)
tab_out)) (((VName, (VName, (PrimType, Stm Kernels)))
  -> BinderT Kernels (State VNameSource) VName)
 -> Binder Kernels [VName])
-> ((VName, (VName, (PrimType, Stm Kernels)))
    -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
loc_Y_nm, (VName
glb_Y_nm, (PrimType
ptp_Y, Stm Kernels
load_Y))) -> do
                      VName
ltid_flat <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"ltid_flat"
                      VName
ltid <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"ltid"
                      let segspace :: SegSpace
segspace = VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
ltid_flat [(VName
ltid, SubExp
group_size)]
                      ((SubExp
res_v, SubExp
res_i), Stms Kernels
stms) <- Binder Kernels (SubExp, SubExp)
-> BinderT
     Kernels (State VNameSource) ((SubExp, SubExp), Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore a -> m (a, Stms lore)
runBinder (Binder Kernels (SubExp, SubExp)
 -> BinderT
      Kernels (State VNameSource) ((SubExp, SubExp), Stms Kernels))
-> Binder Kernels (SubExp, SubExp)
-> BinderT
     Kernels (State VNameSource) ((SubExp, SubExp), Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
                        VName
offs <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"offs" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
group_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
tt)
                        VName
loc_ind <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"loc_ind" (ExpT Kernels -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
offs)
                        [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
gtid_z] (ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ii TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
loc_ind)
                        let glb_ind :: VName
glb_ind = VName
gtid_z
                        SubExp
y_elm <-
                          String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"y_elem"
                            (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
  Kernels
  (State VNameSource)
  (Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
                              (TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
glb_ind TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_M)
                              ( do
                                  Stm (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm Stm (Lore (BinderT Kernels (State VNameSource)))
Stm Kernels
load_Y
                                  VName
res <- String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"Y_elem" VName
glb_Y_nm [VName
q]
                                  [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [VName -> SubExp
Var VName
res]
                              )
                              ([BinderT
   Kernels
   (State VNameSource)
   (Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody [Type
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *). MonadBinder m => Type -> m (Exp (Lore m))
eBlank (Type
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> Type
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
ptp_Y])
                        SubExp
y_ind <-
                          String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"y_loc_ind"
                            (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
  Kernels
  (State VNameSource)
  (Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
                              (TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
loc_ind TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
rz)
                              (VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp VName
loc_ind BinderT Kernels (State VNameSource) (ExpT Kernels)
-> (ExpT Kernels -> Binder Kernels [SubExp])
-> Binder Kernels [SubExp]
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels [SubExp]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [SubExp]
letTupExp' String
"loc_fi" Binder Kernels [SubExp]
-> ([SubExp] -> Binder Kernels (Body Kernels))
-> Binder Kernels (Body Kernels)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= [SubExp] -> Binder Kernels (Body Kernels)
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM)
                              ([BinderT
   Kernels
   (State VNameSource)
   (Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody [ExpT Kernels -> BinderT Kernels (State VNameSource) (ExpT Kernels)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpT Kernels
 -> BinderT Kernels (State VNameSource) (ExpT Kernels))
-> ExpT Kernels
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1)])
                        --y_tp  <- subExpType y_elm
                        (SubExp, SubExp) -> Binder Kernels (SubExp, SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
y_elm, SubExp
y_ind)

                      let ret :: KernelResult
ret = Shape -> VName -> [([DimIndex SubExp], SubExp)] -> KernelResult
WriteReturns ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
rz]) VName
loc_Y_nm [([SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
res_i], SubExp
res_v)]
                      let body :: KernelBody Kernels
body = BodyDec Kernels
-> Stms Kernels -> [KernelResult] -> KernelBody Kernels
forall lore.
BodyDec lore -> Stms lore -> [KernelResult] -> KernelBody lore
KernelBody () Stms Kernels
stms [KernelResult
ret]

                      [VName]
res_nms <-
                        String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels [VName]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [VName]
letTupExp String
"Y_glb2loc" (ExpT Kernels -> Binder Kernels [VName])
-> (ExpT Kernels
    -> BinderT Kernels (State VNameSource) (ExpT Kernels))
-> ExpT Kernels
-> Binder Kernels [VName]
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< ExpT Kernels -> BinderT Kernels (State VNameSource) (ExpT Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Exp lore -> m (Exp lore)
renameExp (ExpT Kernels -> Binder Kernels [VName])
-> ExpT Kernels -> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$
                          Op Kernels -> ExpT Kernels
forall lore. Op lore -> ExpT lore
Op (Op Kernels -> ExpT Kernels) -> Op Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels)
forall lore op. SegOp SegLevel lore -> HostOp lore op
SegOp (SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels))
-> SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels)
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [Type]
-> KernelBody Kernels
-> SegOp SegLevel Kernels
forall lvl lore.
lvl -> SegSpace -> [Type] -> KernelBody lore -> SegOp lvl lore
SegMap SegLevel
segthd_lvl SegSpace
segspace [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
ptp_Y] KernelBody Kernels
body
                      let VName
res_nm : [VName]
_ = [VName]
res_nms
                      VName -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
res_nm
                  [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp]
 -> BinderT
      Kernels
      (State VNameSource)
      (Body (Lore (BinderT Kernels (State VNameSource)))))
-> [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
loc_arr_merge2_nms'

              [VName]
redomap_res <-
                String
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
segMap2D String
"redomap_res" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
ty, SubExp
tx) (((VName, VName) -> Binder Kernels [SubExp])
 -> Binder Kernels [VName])
-> ((VName, VName) -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$
                  \(VName
ltid_y, VName
ltid_x) -> do
                    [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
gtid_y] (ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
jj1 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_y)
                    [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
gtid_x] (ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
jj2 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_x)
                    [VName]
reg_arr_merge_nms_slc <- [VName]
-> (VName -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
reg_arr_merge_nms ((VName -> BinderT Kernels (State VNameSource) VName)
 -> Binder Kernels [VName])
-> (VName -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \VName
reg_arr_nm ->
                      String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"res_reg_slc" VName
reg_arr_nm [VName
ltid_y, VName
ltid_x]
                    String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels [SubExp]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [SubExp]
letTupExp' String
"redomap_guarded"
                      (ExpT Kernels -> Binder Kernels [SubExp])
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels [SubExp]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
  Kernels
  (State VNameSource)
  (Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
                        (TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_y TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_Ky TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_x TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_Kx)
                        ( do
                            [VName]
inp_scals_invar_outer <-
                              [(VName, Stm Kernels)]
-> ((VName, Stm Kernels)
    -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (Map VName (Stm Kernels) -> [(VName, Stm Kernels)]
forall k a. Map k a -> [(k, a)]
M.toList Map VName (Stm Kernels)
tab_inn) (((VName, Stm Kernels)
  -> BinderT Kernels (State VNameSource) VName)
 -> Binder Kernels [VName])
-> ((VName, Stm Kernels)
    -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
inp_arr_nm, Stm Kernels
load_stm) -> do
                                Stm (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm Stm (Lore (BinderT Kernels (State VNameSource)))
Stm Kernels
load_stm
                                String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index (VName -> String
baseString VName
inp_arr_nm) VName
inp_arr_nm [VName
q]
                            -- build the loop of count R whose body is semantically the redomap code
                            [VName]
reg_arr_merge_nms' <-
                              SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> Binder Kernels [VName]
forLoop' SubExp
rz [VName]
reg_arr_merge_nms_slc ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> Binder Kernels [VName])
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \VName
i [VName]
reg_arr_mm_nms -> do
                                [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
gtid_z] (ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ii TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i)
                                [SubExp] -> Binder Kernels (Body Kernels)
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp] -> Binder Kernels (Body Kernels))
-> Binder Kernels [SubExp] -> Binder Kernels (Body Kernels)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels [SubExp]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [SubExp]
letTupExp' String
"redomap_lam"
                                  (ExpT Kernels -> Binder Kernels [SubExp])
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels [SubExp]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
  Kernels
  (State VNameSource)
  (Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
                                    (TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_z TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_M)
                                    ( do
                                        -- read from shared memory
                                        [VName]
ys <- [VName]
-> (VName -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
loc_arr_nms' ((VName -> BinderT Kernels (State VNameSource) VName)
 -> Binder Kernels [VName])
-> (VName -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \VName
loc_arr_nm ->
                                          String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"inp_reg_var2z" VName
loc_arr_nm [VName
i]
                                        [VName]
cs <- [VName]
-> (VName -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
reg_arr_mm_nms ((VName -> BinderT Kernels (State VNameSource) VName)
 -> Binder Kernels [VName])
-> (VName -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \VName
reg_arr_nm ->
                                          String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"res_reg_var2z" VName
reg_arr_nm [VName
i]
                                        -- here we need to put in order the scalar inputs to map:
                                        let tab_scals :: Map VName VName
tab_scals =
                                              [(VName, VName)] -> Map VName VName
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, VName)] -> Map VName VName)
-> [(VName, VName)] -> Map VName VName
forall a b. (a -> b) -> a -> b
$
                                                [VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (((VName, (PrimType, Stm Kernels)) -> VName)
-> [(VName, (PrimType, Stm Kernels))] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, (PrimType, Stm Kernels)) -> VName
forall a b. (a, b) -> a
fst ([(VName, (PrimType, Stm Kernels))] -> [VName])
-> [(VName, (PrimType, Stm Kernels))] -> [VName]
forall a b. (a -> b) -> a -> b
$ Map VName (PrimType, Stm Kernels)
-> [(VName, (PrimType, Stm Kernels))]
forall k a. Map k a -> [(k, a)]
M.toList Map VName (PrimType, Stm Kernels)
tab_out) [VName]
ys
                                                  [(VName, VName)] -> [(VName, VName)] -> [(VName, VName)]
forall a. [a] -> [a] -> [a]
++ [VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (((VName, Stm Kernels) -> VName)
-> [(VName, Stm Kernels)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, Stm Kernels) -> VName
forall a b. (a, b) -> a
fst ([(VName, Stm Kernels)] -> [VName])
-> [(VName, Stm Kernels)] -> [VName]
forall a b. (a -> b) -> a -> b
$ Map VName (Stm Kernels) -> [(VName, Stm Kernels)]
forall k a. Map k a -> [(k, a)]
M.toList Map VName (Stm Kernels)
tab_inn) [VName]
inp_scals_invar_outer
                                        [VName]
map_inp_scals <- [VName]
-> (VName -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
inp_soac_arrs ((VName -> BinderT Kernels (State VNameSource) VName)
 -> Binder Kernels [VName])
-> (VName -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \VName
arr_nm ->
                                          case VName -> Map VName VName -> Maybe VName
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
arr_nm Map VName VName
tab_scals of
                                            Maybe VName
Nothing -> String -> BinderT Kernels (State VNameSource) VName
forall a. HasCallStack => String -> a
error String
"Impossible case reached in tiling3D\n"
                                            Just VName
nm -> VName -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
nm
                                        [VName]
map_res_scals <- [Type]
-> (Type -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (Lambda Kernels -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda Kernels
map_lam) ((Type -> BinderT Kernels (State VNameSource) VName)
 -> Binder Kernels [VName])
-> (Type -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \Type
_ -> String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"map_res"
                                        Lambda Kernels
map_lam' <- Lambda Kernels
-> BinderT Kernels (State VNameSource) (Lambda Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda Lambda Kernels
map_lam
                                        Lambda Kernels
red_lam' <- Lambda Kernels
-> BinderT Kernels (State VNameSource) (Lambda Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda Lambda Kernels
red_lam
                                        Stms (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) ())
-> Stms (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
                                          Lambda Kernels -> [VName] -> [VName] -> Stms Kernels
rebindLambda Lambda Kernels
map_lam' [VName]
map_inp_scals [VName]
map_res_scals
                                            Stms Kernels -> Stms Kernels -> Stms Kernels
forall a. Semigroup a => a -> a -> a
<> Lambda Kernels -> [VName] -> [VName] -> Stms Kernels
rebindLambda Lambda Kernels
red_lam' ([VName]
cs [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
map_res_scals) [VName]
cs
                                        [VName]
css <- [(VName, VName)]
-> ((VName, VName) -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
reg_arr_mm_nms [VName]
cs) (((VName, VName) -> BinderT Kernels (State VNameSource) VName)
 -> Binder Kernels [VName])
-> ((VName, VName) -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
reg_arr_nm, VName
c) ->
                                          String
-> VName
-> [VName]
-> VName
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> VName -> m VName
update (VName -> String
baseString VName
reg_arr_nm) VName
reg_arr_nm [VName
i] VName
c
                                        [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp]
 -> BinderT
      Kernels
      (State VNameSource)
      (Body (Lore (BinderT Kernels (State VNameSource)))))
-> [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
css
                                    )
                                    ([SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp]
 -> BinderT
      Kernels
      (State VNameSource)
      (Body (Lore (BinderT Kernels (State VNameSource)))))
-> [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
reg_arr_mm_nms)
                            [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp]
 -> BinderT
      Kernels
      (State VNameSource)
      (Body (Lore (BinderT Kernels (State VNameSource)))))
-> [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
reg_arr_merge_nms'
                        )
                        ([SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp]
 -> BinderT
      Kernels
      (State VNameSource)
      (Body (Lore (BinderT Kernels (State VNameSource)))))
-> [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
reg_arr_merge_nms_slc)
              [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp]
 -> BinderT
      Kernels
      (State VNameSource)
      (Body (Lore (BinderT Kernels (State VNameSource)))))
-> [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> [SubExp]) -> [VName] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ [VName]
redomap_res [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
loc_arr_nms'

        -- support for non-empty code2'
        --  segmap (ltid_y < ty, ltid_x < tx) {
        --    for i < rz do
        --        res = if (ii+i < d_M && jj1+ltid_y < d_Ky && jj2 + ltid_x < d_Kx)
        --              then code2' else dummy
        --        final_res[i] = res
        let redomap_res :: [VName]
redomap_res = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
take ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) [VName]
prologue_res_list
        [VName]
epilogue_res <-
          if [PatElemT Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElemT Type]
redomap_orig_res Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
ker_res_nms
            Bool -> Bool -> Bool
&& [VName]
ker_res_nms [VName] -> [VName] -> Bool
forall a. Eq a => a -> a -> Bool
== (PatElemT Type -> VName) -> [PatElemT Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map PatElemT Type -> VName
forall dec. PatElemT dec -> VName
patElemName [PatElemT Type]
redomap_orig_res
            then -- all (\ (a,b) -> patElemName a == b ) $ zip redomap_orig_res ker_res_nms
            String
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp, SubExp)
-> ((VName, VName, VName) -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
segMap3D String
"rssss" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
se1, SubExp
ty, SubExp
tx) (((VName, VName, VName) -> Binder Kernels [SubExp])
 -> Binder Kernels [VName])
-> ((VName, VName, VName) -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
_ltid_z, VName
ltid_y, VName
ltid_x) ->
              [(Type, VName)]
-> ((Type, VName) -> Binder Kernels SubExp)
-> Binder Kernels [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Type] -> [VName] -> [(Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Type]
kertp [VName]
redomap_res) (((Type, VName) -> Binder Kernels SubExp)
 -> Binder Kernels [SubExp])
-> ((Type, VName) -> Binder Kernels SubExp)
-> Binder Kernels [SubExp]
forall a b. (a -> b) -> a -> b
$ \(Type
res_tp, VName
res) -> do
                VName
rss_init <- String
-> PrimType
-> [SubExp]
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> PrimType -> [SubExp] -> m VName
scratch String
"rss_init" (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
res_tp) [SubExp
rz, SubExp
se1, SubExp
se1]
                (VName -> SubExp)
-> BinderT Kernels (State VNameSource) VName
-> Binder Kernels SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> SubExp
Var (BinderT Kernels (State VNameSource) VName
 -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) VName
-> Binder Kernels SubExp
forall a b. (a -> b) -> a -> b
$
                  SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forLoop SubExp
rz [VName
rss_init] ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> BinderT Kernels (State VNameSource) VName)
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
i [VName
rss] -> do
                    let slice :: [DimIndex SubExp]
slice = [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i, SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
se0, SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
se0]
                    VName
thread_res <- String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"thread_res" VName
res [VName
ltid_y, VName
ltid_x, VName
i]
                    SubExp
rss' <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"rss" (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> Binder Kernels SubExp)
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ VName -> [DimIndex SubExp] -> SubExp -> BasicOp
Update VName
rss [DimIndex SubExp]
slice (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
thread_res
                    [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [SubExp
rss']
            else String
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp, SubExp)
-> ((VName, VName, VName) -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
segMap3D String
"rssss" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
se1, SubExp
ty, SubExp
tx) (((VName, VName, VName) -> Binder Kernels [SubExp])
 -> Binder Kernels [VName])
-> ((VName, VName, VName) -> Binder Kernels [SubExp])
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
_ltid_z, VName
ltid_y, VName
ltid_x) -> do
              [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
gtid_y] (ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
jj1 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_y)
              [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
gtid_x] (ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
jj2 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_x)
              [VName]
rss_init <- [Type]
-> (Type -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Type]
kertp ((Type -> BinderT Kernels (State VNameSource) VName)
 -> Binder Kernels [VName])
-> (Type -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \Type
res_tp ->
                String
-> PrimType
-> [SubExp]
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> PrimType -> [SubExp] -> m VName
scratch String
"rss_init" (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
res_tp) [SubExp
rz, SubExp
se1, SubExp
se1]
              [VName]
rss <- SubExp
-> [VName]
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> Binder Kernels [VName]
forLoop' SubExp
rz [VName]
rss_init ((VName -> [VName] -> Binder Kernels (Body Kernels))
 -> Binder Kernels [VName])
-> (VName -> [VName] -> Binder Kernels (Body Kernels))
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \VName
i [VName]
rss_merge -> do
                [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
gtid_z] (ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ii TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i)
                [(PatElemT Type, VName)]
-> ((PatElemT Type, VName)
    -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT Type] -> [VName] -> [(PatElemT Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT Type]
redomap_orig_res [VName]
redomap_res) (((PatElemT Type, VName)
  -> BinderT Kernels (State VNameSource) VName)
 -> BinderT Kernels (State VNameSource) ())
-> ((PatElemT Type, VName)
    -> BinderT Kernels (State VNameSource) VName)
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT Type
o_res, VName
n_res) -> do
                  VName
c <- String
-> VName -> [VName] -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> VName -> [VName] -> m VName
index String
"redomap_thd" VName
n_res [VName
ltid_y, VName
ltid_x, VName
i]
                  [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT Type -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT Type
o_res] (ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
c)
                  VName -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
c
                [SubExp]
res_els <-
                  String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels [SubExp]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [SubExp]
letTupExp' String
"res_elem"
                    (ExpT Kernels -> Binder Kernels [SubExp])
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels [SubExp]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
  Kernels
  (State VNameSource)
  (Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
                      ( TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$
                          VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_y TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_Ky
                            TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_x TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_Kx
                            TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_z TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_M
                      )
                      ( do
                          Stms (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms Stms (Lore (BinderT Kernels (State VNameSource)))
Stms Kernels
code2'
                          [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp]
 -> BinderT
      Kernels
      (State VNameSource)
      (Body (Lore (BinderT Kernels (State VNameSource)))))
-> [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
ker_res_nms
                      )
                      ([BinderT
   Kernels
   (State VNameSource)
   (Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody ([BinderT
    Kernels
    (State VNameSource)
    (Exp (Lore (BinderT Kernels (State VNameSource))))]
 -> BinderT
      Kernels
      (State VNameSource)
      (Body (Lore (BinderT Kernels (State VNameSource)))))
-> [BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (Type -> BinderT Kernels (State VNameSource) (ExpT Kernels))
-> [Type] -> [BinderT Kernels (State VNameSource) (ExpT Kernels)]
forall a b. (a -> b) -> [a] -> [b]
map Type -> BinderT Kernels (State VNameSource) (ExpT Kernels)
forall (m :: * -> *). MonadBinder m => Type -> m (Exp (Lore m))
eBlank [Type]
kertp)
                [SubExp]
rss' <- [(SubExp, VName)]
-> ((SubExp, VName) -> Binder Kernels SubExp)
-> Binder Kernels [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SubExp] -> [VName] -> [(SubExp, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
res_els [VName]
rss_merge) (((SubExp, VName) -> Binder Kernels SubExp)
 -> Binder Kernels [SubExp])
-> ((SubExp, VName) -> Binder Kernels SubExp)
-> Binder Kernels [SubExp]
forall a b. (a -> b) -> a -> b
$ \(SubExp
res_el, VName
rs_merge) -> do
                  let slice :: [DimIndex SubExp]
slice = [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i, SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
se0, SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
se0]
                  String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"rss" (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> Binder Kernels SubExp)
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ VName -> [DimIndex SubExp] -> SubExp -> BasicOp
Update VName
rs_merge [DimIndex SubExp]
slice SubExp
res_el
                [SubExp]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [SubExp]
rss'
              [SubExp] -> Binder Kernels [SubExp]
forall (m :: * -> *) a. Monad m => a -> m a
return ([SubExp] -> Binder Kernels [SubExp])
-> [SubExp] -> Binder Kernels [SubExp]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
rss

        ----------------------------------------------------------------
        -- Finally, reshape the result arrays for the RegTileReturn  ---
        ----------------------------------------------------------------
        let regtile_ret_dims :: [(SubExp, SubExp, SubExp)]
regtile_ret_dims =
              ((VName, SubExp) -> (SubExp, SubExp, SubExp))
-> [(VName, SubExp)] -> [(SubExp, SubExp, SubExp)]
forall a b. (a -> b) -> [a] -> [b]
map (\(VName
_, SubExp
sz) -> (SubExp
sz, SubExp
se1, SubExp
se1)) [(VName, SubExp)]
rem_outer_dims
                [(SubExp, SubExp, SubExp)]
-> [(SubExp, SubExp, SubExp)] -> [(SubExp, SubExp, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(SubExp
d_M, SubExp
se1, SubExp
rz), (SubExp
d_Ky, SubExp
ty, SubExp
se1), (SubExp
d_Kx, SubExp
tx, SubExp
se1)]

        [VName]
epilogue_res' <- [VName]
-> (VName -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
epilogue_res ((VName -> BinderT Kernels (State VNameSource) VName)
 -> Binder Kernels [VName])
-> (VName -> BinderT Kernels (State VNameSource) VName)
-> Binder Kernels [VName]
forall a b. (a -> b) -> a -> b
$ \VName
res ->
          if [(VName, SubExp)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(VName, SubExp)]
rem_outer_dims
            then VName -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
res
            else do
              -- Add dummy dimensions to tile to reflect the outer dimensions
              Type
res_tp' <- VName -> BinderT Kernels (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
res
              let ([SubExp]
block_dims, [SubExp]
rest_dims) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
2 ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$ Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
res_tp'
                  ones :: [SubExp]
ones = ((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> (VName, SubExp) -> SubExp
forall a b. a -> b -> a
const SubExp
se1) [(VName, SubExp)]
rem_outer_dims
                  new_shape :: [SubExp]
new_shape = [[SubExp]] -> [SubExp]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[SubExp]
ones, [SubExp]
block_dims, [SubExp]
ones, [SubExp]
rest_dims]
              String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"res_reshaped" (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) VName)
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> VName -> BasicOp
Reshape ((SubExp -> DimChange SubExp) -> [SubExp] -> ShapeChange SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimChange SubExp
forall d. d -> DimChange d
DimNew [SubExp]
new_shape) VName
res

        [KernelResult] -> Binder Kernels [KernelResult]
forall (m :: * -> *) a. Monad m => a -> m a
return ([KernelResult] -> Binder Kernels [KernelResult])
-> [KernelResult] -> Binder Kernels [KernelResult]
forall a b. (a -> b) -> a -> b
$ (VName -> KernelResult) -> [VName] -> [KernelResult]
forall a b. (a -> b) -> [a] -> [b]
map ([(SubExp, SubExp, SubExp)] -> VName -> KernelResult
RegTileReturns [(SubExp, SubExp, SubExp)]
regtile_ret_dims) [VName]
epilogue_res'
      -- END (ret_seggroup, stms_seggroup) <- runBinder $ do
      let level' :: SegLevel
level' = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegGroup (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Count SubExp
grid_size) (SubExp -> Count GroupSize SubExp
forall u e. e -> Count u e
Count SubExp
group_size) SegVirt
SegNoVirt
          space' :: SegSpace
space' = VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat ([(VName, SubExp)]
rem_outer_dims [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
gid_z, SubExp
gridDim_z), (VName
gid_y, SubExp
gridDim_y), (VName
gid_x, SubExp
gridDim_x)])
          kbody' :: KernelBody Kernels
kbody' = BodyDec Kernels
-> Stms Kernels -> [KernelResult] -> KernelBody Kernels
forall lore.
BodyDec lore -> Stms lore -> [KernelResult] -> KernelBody lore
KernelBody () Stms Kernels
stms_seggroup [KernelResult]
ret_seggroup

      Stm Kernels -> Binder Kernels (Stm Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stm Kernels -> Binder Kernels (Stm Kernels))
-> Stm Kernels -> Binder Kernels (Stm Kernels)
forall a b. (a -> b) -> a -> b
$ Pattern Kernels
-> StmAux (ExpDec Kernels) -> ExpT Kernels -> Stm Kernels
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let Pattern Kernels
pat StmAux (ExpDec Kernels)
aux (ExpT Kernels -> Stm Kernels) -> ExpT Kernels -> Stm Kernels
forall a b. (a -> b) -> a -> b
$ Op Kernels -> ExpT Kernels
forall lore. Op lore -> ExpT lore
Op (Op Kernels -> ExpT Kernels) -> Op Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels)
forall lore op. SegOp SegLevel lore -> HostOp lore op
SegOp (SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels))
-> SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels)
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [Type]
-> KernelBody Kernels
-> SegOp SegLevel Kernels
forall lvl lore.
lvl -> SegSpace -> [Type] -> KernelBody lore -> SegOp lvl lore
SegMap SegLevel
level' SegSpace
space' [Type]
kertp KernelBody Kernels
kbody'
    -- END (new_kernel, host_stms) <- runBinder $ do
    Maybe (Stms Kernels, Stm Kernels)
-> TileM (Maybe (Stms Kernels, Stm Kernels))
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe (Stms Kernels, Stm Kernels)
 -> TileM (Maybe (Stms Kernels, Stm Kernels)))
-> Maybe (Stms Kernels, Stm Kernels)
-> TileM (Maybe (Stms Kernels, Stm Kernels))
forall a b. (a -> b) -> a -> b
$ (Stms Kernels, Stm Kernels) -> Maybe (Stms Kernels, Stm Kernels)
forall a. a -> Maybe a
Just (Stms Kernels
host_stms, Stm Kernels
new_kernel)
  where
    -- | Extract the variable name carried by a kernel result of the exact
    -- form @Returns ResultMaySimplify (Var v)@. Every other kind of kernel
    -- result (including a constant 'SubExp') yields 'Nothing'.
    getResNm :: KernelResult -> Maybe VName
    getResNm kres =
      case kres of
        Returns ResultMaySimplify (Var nm) -> Just nm
        _ -> Nothing

    -- | Bind, under the base name @name@, the signed 64-bit minimum of the
    -- tile size @tile@ and the dimension @dim@. The resulting sub-expression
    -- is therefore never larger than @dim@.
    limitTile :: String -> SubExp -> SubExp -> Binder Kernels SubExp
    limitTile name tile dim =
      let clamped = BasicOp (BinOp (SMin Int64) tile dim)
       in letSubExp name clamped
    -- | Classify one statement @(p_nm, stm)@ into one of two tables.
    --
    -- @stm@ must be a single-element indexing statement @p_nm = arr[slc]@.
    --
    -- * If no fixed index of @slc@ is a variable that 'variantToDim' reports
    --   as variant to @gidz@, @stm@ is inserted unchanged into the first
    --   table (keyed by @p_nm@).
    -- * Otherwise, let @i@ be the position of the first such index. A copy
    --   of @arr@ named @<arr>_transp@ is bound via @Manifest perm@, where
    --   @perm@ moves dimensions @0..i@ to the back. The statement is then
    --   rewritten to index that copy with the same slice, and
    --   (element type, rewritten statement) is inserted into the second
    --   table.
    --
    -- NOTE(review): reusing @slc@ on the manifested copy assumes 'Manifest'
    -- changes only the memory layout, not the logical shape; confirm against
    -- the IR definition.
    --
    -- Any statement not of the expected shape hits 'error'.
    insertTranspose ::
      VarianceTable ->
      (VName, SubExp) ->
      (M.Map VName (Stm Kernels), M.Map VName (PrimType, Stm Kernels)) ->
      (VName, Stm Kernels) ->
      Binder Kernels (M.Map VName (Stm Kernels), M.Map VName (PrimType, Stm Kernels))
    insertTranspose variance (gidz, _) (tab_inn, tab_out) (p_nm, stm)
      | Let patt aux (BasicOp (Index arr_nm slc)) <- stm,
        [pe] <- patternValueElements patt,
        patElemName pe == p_nm =
        case L.findIndex (variantSliceDim variance gidz) slc of
          Nothing ->
            -- No fixed index varies with gidz: keep the statement as is.
            pure (M.insert p_nm stm tab_inn, tab_out)
          Just i -> do
            rank <- arrayRank <$> lookupType arr_nm
            -- Rotate so that dimensions 0..i end up last.
            let perm = [i + 1 .. rank - 1] ++ [0 .. i]
            arr_tr_nm <-
              letExp (baseString arr_nm ++ "_transp") . BasicOp $
                Manifest perm arr_nm
            let stm' = Let patt aux $ BasicOp $ Index arr_tr_nm slc
                ptp = elemType $ patElemType pe
            pure (tab_inn, M.insert p_nm (ptp, stm') tab_out)
      | otherwise =
        error "\nUnreachable case reached in insertTranspose case, doRegTiling3D\n"

    -- | Does this slice component vary with @gidz@?
    --
    -- Only a fixed index given by a variable is checked, delegating to
    -- 'variantToDim'. Ranges and constant fixed indices are always reported
    -- as invariant ('False').
    variantSliceDim :: VarianceTable -> VName -> DimIndex SubExp -> Bool
    variantSliceDim variance gidz dim_idx =
      case dim_idx of
        DimFix (Var idx_nm) -> variantToDim variance gidz idx_nm
        _ -> False
doRegTiling3D Stm Kernels
_ = Maybe (Stms Kernels, Stm Kernels)
-> TileM (Maybe (Stms Kernels, Stm Kernels))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (Stms Kernels, Stm Kernels)
forall a. Maybe a
Nothing