{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.BlkRegTiling (mmBlkRegTiling, doRegTiling3D) where
import Control.Monad
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Sequence qualified as Seq
import Futhark.IR.GPU
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.MonadFreshNames
import Futhark.Optimise.TileLoops.Shared
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
-- | The constant @0@ as an 'Int64' 'SubExp'.
se0 :: SubExp
se0 = intConst Int64 0
-- | The constant @1@ as an 'Int64' 'SubExp'.
se1 :: SubExp
se1 = intConst Int64 1
-- | The constant @2@ as an 'Int64' 'SubExp'.
se2 :: SubExp
se2 = intConst Int64 2
-- | The constant @4@ as an 'Int64' 'SubExp'.
se4 :: SubExp
se4 = intConst Int64 4
-- | The constant @8@ as an 'Int64' 'SubExp'.
se8 :: SubExp
se8 = intConst Int64 8
-- | Inspect the statement that slices an input array and report whether the
-- innermost dimension of the indexed array has stride 1.
--
-- The statement must be a single-result 'Index' whose bound name equals
-- @slc_X@.  The indexed array @x@ is then looked up in the index-function
-- part of the environment:
--
-- * no entry: the answer is 'True';
-- * an LMAD entry: the answer is whether the stride of its last dimension
--   equals the 'Int64' constant 1.
--
-- Any other statement, or a name mismatch, deliberately calls 'error'.
isInnerCoal :: Env -> VName -> Stm GPU -> Bool
isInnerCoal (_, ixfn_env) slc_X (Let (Pat [pe]) _ (BasicOp (Index x _)))
  | slc_X == patElemName pe =
      case M.lookup x ixfn_env of
        -- Nothing recorded for the indexed array: answer True.
        Nothing -> True
        Just lmad -> innerHasStride1 lmad
  where
    -- NOTE(review): 'last' is partial; this presumes the LMAD of an indexed
    -- array always has at least one dimension -- confirm.
    innerHasStride1 lmad =
      let innermost = last (LMAD.dims lmad)
       in LMAD.ldStride innermost == pe64 (intConst Int64 1)
isInnerCoal _ _ _ =
  error "kkLoopBody.isInnerCoal: not an error, but I would like to know why!"
-- | Bind a fresh variable to a 'Scratch' basic operation and return its name.
--
-- @se_name@ is the base name passed to 'letExp', @t@ is the element type and
-- @shape@ the list of dimension sizes handed to 'Scratch'.
scratch :: (MonadBuilder m) => String -> PrimType -> [SubExp] -> m VName
scratch se_name t shape =
  letExp se_name (BasicOp (Scratch t shape))
kkLoopBody ::
Env ->
( (SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp),
SegLevel,
[Int],
(VName, SubExp, VName, SubExp, SubExp),
(VName, VName),
(Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
(Lambda GPU, Lambda GPU)
) ->
VName ->
(VName, VName, VName) ->
Bool ->
Builder GPU [VName]
kkLoopBody :: Env
-> ((SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp,
SubExp),
SegLevel, [Int], (VName, SubExp, VName, SubExp, SubExp),
(VName, VName),
(Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
(Lambda GPU, Lambda GPU))
-> VName
-> (VName, VName, VName)
-> Bool
-> Builder GPU [VName]
kkLoopBody
Env
env
( (SubExp
rx, SubExp
ry, SubExp
tx, SubExp
ty, SubExp
tk, SubExp
tk_div_tx, SubExp
_tk_div_ty, SubExp
tx_rx),
SegLevel
segthd_lvl,
[Int]
var_dims,
(VName
gtid_x, SubExp
width_B, VName
gtid_y, SubExp
height_A, SubExp
common_dim),
(VName
iii, VName
jjj),
(Stm GPU
load_A, VName
inp_A, PrimType
pt_A, Stm GPU
load_B, VName
inp_B, PrimType
pt_B),
(Lambda GPU
map_lam, Lambda GPU
red_lam)
)
VName
kk0
(VName
thd_res_merge, VName
a_loc_init', VName
b_loc_init')
Bool
epilogue = do
let (PrimType
map_t1, PrimType
map_t2) = (PrimType
pt_A, PrimType
pt_B)
VName
kk <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"kk" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
kk0 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tk)
(VName
a_loc, VName -> VName -> VName -> BuilderT GPU (State VNameSource) VName
aCopyLoc2Reg) <-
Bool
-> VName
-> (VName, VName, PrimType, SubExp, VName, Stm GPU, VName)
-> Builder
GPU
(VName,
VName -> VName -> VName -> BuilderT GPU (State VNameSource) VName)
copyGlb2ShMem Bool
False VName
kk (VName
gtid_y, VName
iii, PrimType
map_t1, SubExp
height_A, VName
inp_A, Stm GPU
load_A, VName
a_loc_init')
(VName
b_loc, VName -> VName -> VName -> BuilderT GPU (State VNameSource) VName
bCopyLoc2Reg) <-
Bool
-> VName
-> (VName, VName, PrimType, SubExp, VName, Stm GPU, VName)
-> Builder
GPU
(VName,
VName -> VName -> VName -> BuilderT GPU (State VNameSource) VName)
copyGlb2ShMem Bool
True VName
kk (VName
gtid_x, VName
jjj, PrimType
map_t2, SubExp
width_B, VName
inp_B, Stm GPU
load_B, VName
b_loc_init')
VName
thd_acc <- VName
-> VName
-> (VName
-> VName -> VName -> BuilderT GPU (State VNameSource) VName)
-> (VName
-> VName -> VName -> BuilderT GPU (State VNameSource) VName)
-> Bool
-> BuilderT GPU (State VNameSource) VName
mkRedomapOneTileBody VName
kk VName
thd_res_merge VName -> VName -> VName -> BuilderT GPU (State VNameSource) VName
aCopyLoc2Reg VName -> VName -> VName -> BuilderT GPU (State VNameSource) VName
bCopyLoc2Reg Bool
True
[VName] -> Builder GPU [VName]
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName
thd_acc, VName
a_loc, VName
b_loc]
where
mk_ik :: Bool
-> Bool
-> (VName, VName)
-> (VName, VName)
-> BuilderT
GPU (State VNameSource) (VName, VName, TPrimExp Int64 VName)
mk_ik Bool
is_B Bool
is_coal (VName
thd_y, VName
thd_x) (VName
i0, VName
k0)
| Bool
is_coal = do
let (SubExp
t_par, SubExp
t_seq) = (SubExp
tx, SubExp
tk)
VName
k <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"k" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
thd_x TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
k0 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
t_par)
VName
i <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"i" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
thd_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i0 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
t_par)
let pad_term :: TPrimExp Int64 VName
pad_term = if Bool
is_B then SubExp -> TPrimExp Int64 VName
pe64 SubExp
se1 else SubExp -> TPrimExp Int64 VName
pe64 SubExp
se0
let e :: TPrimExp Int64 VName
e = VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
k TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* (SubExp -> TPrimExp Int64 VName
pe64 SubExp
t_seq TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
pad_term)
(VName, VName, TPrimExp Int64 VName)
-> BuilderT
GPU (State VNameSource) (VName, VName, TPrimExp Int64 VName)
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName
i, VName
k, TPrimExp Int64 VName
e)
mk_ik Bool
_ Bool
_ (VName
thd_y, VName
thd_x) (VName
i0, VName
k0) = do
let (SubExp
t_par, SubExp
tr_par) = (SubExp
tx, SubExp
tx_rx)
VName
k <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"k" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
thd_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
k0 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
t_par)
VName
i <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"i" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
thd_x TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i0 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
t_par)
let e :: TPrimExp Int64 VName
e = VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
k TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tr_par
(VName, VName, TPrimExp Int64 VName)
-> BuilderT
GPU (State VNameSource) (VName, VName, TPrimExp Int64 VName)
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName
i, VName
k, TPrimExp Int64 VName
e)
mkCompLoopRxRy :: Bool
-> VName
-> (VName -> VName -> BuilderT GPU (State VNameSource) VName,
VName -> VName -> BuilderT GPU (State VNameSource) VName)
-> (VName, VName)
-> BuilderT GPU (State VNameSource) (Body GPU)
mkCompLoopRxRy Bool
fits_ij VName
css_init (VName -> VName -> BuilderT GPU (State VNameSource) VName
a_idx_fn, VName -> VName -> BuilderT GPU (State VNameSource) VName
b_idx_fn) (VName
ltid_y, VName
ltid_x) = do
VName
css <- SubExp
-> [VName]
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forLoop SubExp
ry [VName
css_init] ((VName -> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName)
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
i [VName
css_merge] -> do
VName
css <- SubExp
-> [VName]
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forLoop SubExp
rx [VName
css_merge] ((VName -> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName)
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
j [VName
css_merge'] ->
([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
[SubExp] -> BuilderT GPU (State VNameSource) (Body GPU)
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM ([SubExp] -> BuilderT GPU (State VNameSource) (Body GPU))
-> (Exp GPU -> BuilderT GPU (State VNameSource) [SubExp])
-> Exp GPU
-> BuilderT GPU (State VNameSource) (Body GPU)
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) [SubExp]
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [SubExp]
letTupExp' [Char]
"foo")
(Exp GPU -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) (Body GPU)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
( TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource)))))
-> TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$
if Bool
fits_ij
then TPrimExp Bool VName
forall v. TPrimExp Bool v
true
else
(VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
iii TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ SubExp -> TPrimExp Int64 VName
pe64 SubExp
ry TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_y TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
height_A)
TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
jjj TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
j TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ SubExp -> TPrimExp Int64 VName
pe64 SubExp
rx TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_x TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
width_B)
)
( do
VName
a <- VName -> VName -> BuilderT GPU (State VNameSource) VName
a_idx_fn VName
ltid_y VName
i
VName
b <- VName -> VName -> BuilderT GPU (State VNameSource) VName
b_idx_fn VName
ltid_x VName
j
VName
c <- [Char]
-> VName -> [VName] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> m VName
index [Char]
"c" VName
css_merge' [VName
i, VName
j]
Lambda GPU
map_lam' <- Lambda GPU -> BuilderT GPU (State VNameSource) (Lambda GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPU
map_lam
Lambda GPU
red_lam' <- Lambda GPU -> BuilderT GPU (State VNameSource) (Lambda GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPU
red_lam
let map_inp_reg :: [VName]
map_inp_reg = if [Int]
var_dims [Int] -> [Int] -> Bool
forall a. Eq a => a -> a -> Bool
== [Int
0, Int
1] then [VName
a, VName
b] else [VName
b, VName
a]
[SubExpRes]
map_res <- Lambda (Rep (BuilderT GPU (State VNameSource)))
-> [BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))]
-> BuilderT GPU (State VNameSource) [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep (BuilderT GPU (State VNameSource)))
Lambda GPU
map_lam' ((VName -> BuilderT GPU (State VNameSource) (Exp GPU))
-> [VName] -> [BuilderT GPU (State VNameSource) (Exp GPU)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
SubExp -> BuilderT GPU (State VNameSource) (Exp GPU)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> BuilderT GPU (State VNameSource) (Exp GPU))
-> (VName -> SubExp)
-> VName
-> BuilderT GPU (State VNameSource) (Exp GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName]
map_inp_reg)
~[SubExpRes
red_res] <- Lambda (Rep (BuilderT GPU (State VNameSource)))
-> [BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))]
-> BuilderT GPU (State VNameSource) [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep (BuilderT GPU (State VNameSource)))
Lambda GPU
red_lam' ((SubExp
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource)))))
-> [SubExp]
-> [BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))]
forall a b. (a -> b) -> [a] -> [b]
map SubExp
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp]
-> [BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))])
-> [SubExp]
-> [BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))]
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
c SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp [SubExpRes]
map_res)
VName
css <- [Char]
-> VName
-> [VName]
-> SubExp
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> SubExp -> m VName
update [Char]
"css" VName
css_merge' [VName
i, VName
j] (SubExpRes -> SubExp
resSubExp SubExpRes
red_res)
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [VName -> SubExp
Var VName
css]
)
([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [VName -> SubExp
Var VName
css_merge'])
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [VName -> SubExp
Var VName
css]
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [VName -> SubExp
Var VName
css]
mkRedomapOneTileBody :: VName
-> VName
-> (VName
-> VName -> VName -> BuilderT GPU (State VNameSource) VName)
-> (VName
-> VName -> VName -> BuilderT GPU (State VNameSource) VName)
-> Bool
-> BuilderT GPU (State VNameSource) VName
mkRedomapOneTileBody VName
kk VName
css_merge VName -> VName -> VName -> BuilderT GPU (State VNameSource) VName
a_idx_fn VName -> VName -> VName -> BuilderT GPU (State VNameSource) VName
b_idx_fn Bool
fits_ij = do
[VName]
redomap_res <- [Char]
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName]
segMap2D [Char]
"redomap_res" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
ty, SubExp
tx) (((VName, VName) -> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName])
-> ((VName, VName) -> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$
\(VName
ltid_y, VName
ltid_x) -> do
VName
css_init <- [Char]
-> VName -> [VName] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> m VName
index [Char]
"css_init" VName
css_merge [VName
ltid_y, VName
ltid_x]
VName
css <- SubExp
-> [VName]
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forLoop SubExp
tk [VName
css_init] ((VName -> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName)
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
k [VName
acc_merge] ->
([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
[SubExp] -> BuilderT GPU (State VNameSource) (Body GPU)
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM ([SubExp] -> BuilderT GPU (State VNameSource) (Body GPU))
-> (Exp GPU -> BuilderT GPU (State VNameSource) [SubExp])
-> Exp GPU
-> BuilderT GPU (State VNameSource) (Body GPU)
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) [SubExp]
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [SubExp]
letTupExp' [Char]
"foo")
(Exp GPU -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) (Body GPU)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
( TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource)))))
-> TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$
if Bool
epilogue
then VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
kk TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
k TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
common_dim
else TPrimExp Bool VName
forall v. TPrimExp Bool v
true
)
(Bool
-> VName
-> (VName -> VName -> BuilderT GPU (State VNameSource) VName,
VName -> VName -> BuilderT GPU (State VNameSource) VName)
-> (VName, VName)
-> BuilderT GPU (State VNameSource) (Body GPU)
mkCompLoopRxRy Bool
fits_ij VName
acc_merge (VName -> VName -> VName -> BuilderT GPU (State VNameSource) VName
a_idx_fn VName
k, VName -> VName -> VName -> BuilderT GPU (State VNameSource) VName
b_idx_fn VName
k) (VName
ltid_y, VName
ltid_x))
([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [VName -> SubExp
Var VName
acc_merge])
[SubExpRes] -> BuilderT GPU (State VNameSource) [SubExpRes]
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName -> SubExpRes
varRes VName
css]
VName -> BuilderT GPU (State VNameSource) VName
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName -> BuilderT GPU (State VNameSource) VName)
-> VName -> BuilderT GPU (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ [VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
redomap_res
copyGlb2ShMem ::
Bool ->
VName ->
(VName, VName, PrimType, SubExp, VName, Stm GPU, VName) ->
Builder GPU (VName, VName -> VName -> VName -> Builder GPU VName)
copyGlb2ShMem :: Bool
-> VName
-> (VName, VName, PrimType, SubExp, VName, Stm GPU, VName)
-> Builder
GPU
(VName,
VName -> VName -> VName -> BuilderT GPU (State VNameSource) VName)
copyGlb2ShMem Bool
is_B VName
kk (VName
gtid, VName
ii, PrimType
ptp_X_el, SubExp
parlen_X, VName
inp_X, Stm GPU
load_X, VName
x_loc_init') = do
let (SubExp
t_par, SubExp
r_par, SubExp
tseq_div_tpar) = (SubExp
tx, SubExp
rx, SubExp
tk_div_tx)
is_inner_coal :: Bool
is_inner_coal = Env -> VName -> Stm GPU -> Bool
isInnerCoal Env
env VName
inp_X Stm GPU
load_X
str_A :: [Char]
str_A = VName -> [Char]
baseString VName
inp_X
VName
x_loc <-
[Char]
-> VName
-> [SubExp]
-> (SubExp, SubExp)
-> ([VName] -> (VName, VName) -> Builder GPU (SubExp, SubExp))
-> BuilderT GPU (State VNameSource) VName
segScatter2D ([Char]
str_A [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
"_glb2loc") VName
x_loc_init' [SubExp
r_par, SubExp
tseq_div_tpar] (SubExp
t_par, SubExp
t_par) (([VName] -> (VName, VName) -> Builder GPU (SubExp, SubExp))
-> BuilderT GPU (State VNameSource) VName)
-> ([VName] -> (VName, VName) -> Builder GPU (SubExp, SubExp))
-> BuilderT GPU (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$
Bool -> [VName] -> (VName, VName) -> Builder GPU (SubExp, SubExp)
scatterFun Bool
is_inner_coal
(VName,
VName -> VName -> VName -> BuilderT GPU (State VNameSource) VName)
-> Builder
GPU
(VName,
VName -> VName -> VName -> BuilderT GPU (State VNameSource) VName)
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName
x_loc, Bool
-> [Char]
-> VName
-> VName
-> VName
-> VName
-> BuilderT GPU (State VNameSource) VName
indexLocMem Bool
is_inner_coal [Char]
str_A VName
x_loc)
where
-- Read one element of the flat array @x_loc@.
--
-- The flat index is built from @k@, @ltid_yx@ and @ij@ in one of two ways,
-- selected by @is_inner_coal@:
--
--   * coalesced:  k + (ltid_yx * rx + ij) * (tk + pad), where pad is 1 when
--     @is_B@ holds and 0 otherwise;
--   * otherwise:  ij + ltid_yx * rx + k * tx_rx.
--
-- @str_A@ is only used as a prefix for the names of the generated bindings.
indexLocMem ::
  Bool ->
  String ->
  VName ->
  VName ->
  VName ->
  VName ->
  Builder GPU VName
indexLocMem is_inner_coal str_A x_loc k ltid_yx ij = do
  let row = le64 ltid_yx * pe64 rx + le64 ij
      padding = pe64 (if is_B then se1 else se0)
      flat_ix
        | is_inner_coal = le64 k + row * (pe64 tk + padding)
        | otherwise = le64 ij + le64 ltid_yx * pe64 rx + le64 k * pe64 tx_rx
  -- The binding keeps its original "_loc_ind_64" name; the index is 64-bit.
  flat_ix_nm <- toExp flat_ix >>= letExp (str_A ++ "_loc_ind_64")
  index (str_A ++ "_loc_elem") x_loc [flat_ix_nm]
-- Per-element body handed to 'segScatter2D' by 'copyGlb2ShMem'.
--
-- Given the two sequential indices @[i0, k0]@ and the pair of thread indices
-- @(thd_y, thd_x)@, it produces the pair (element value, destination index):
--
--   * 'mk_ik' turns the incoming indices into @i@, @k@ and the destination
--     index expression;
--   * @gtid@ is bound to @ii + i@ and the sequential index to @kk + k@;
--   * the element is loaded (by replaying @load_X@ and indexing @inp_X@) only
--     when @gtid < parlen_X@ and, if @epilogue@ is set, the sequential index
--     is below @common_dim@; otherwise a blank value of type @ptp_X_el@ is
--     used;
--   * the destination index is the one from 'mk_ik' when @k < tk@ and @-1@
--     otherwise (presumably making the scatter skip that write -- confirm
--     against 'segScatter2D').
--
-- Any other shape of the index list is a programming error in the caller.
scatterFun ::
  Bool ->
  [VName] ->
  (VName, VName) ->
  Builder GPU (SubExp, SubExp)
scatterFun is_inner_coal [i0, k0] (thd_y, thd_x) = do
  let prefix = baseString inp_X
  (i, k, dest_ix) <- mk_ik is_B is_inner_coal (thd_y, thd_x) (i0, k0)
  -- @gtid@ must be bound before @load_X@ is replayed below.
  letBindNames [gtid] =<< toExp (le64 ii + le64 i)
  seq_ix <- letExp (prefix ++ "_seqdim_idx") =<< toExp (le64 kk + le64 k)
  let par_in_bounds = le64 gtid .<. pe64 parlen_X
      seq_in_bounds
        | epilogue = le64 seq_ix .<. pe64 common_dim
        | otherwise = true
      loadBody = do
        addStm load_X
        loaded <- index "A_elem" inp_X [seq_ix]
        resultBodyM [Var loaded]
      blankBody = eBody [eBlank (Prim ptp_X_el)]
  elem_val <-
    letSubExp (prefix ++ "_elem")
      =<< eIf (toExp (par_in_bounds .&&. seq_in_bounds)) loadBody blankBody
  elem_dest <-
    letSubExp (prefix ++ "_loc_ind")
      =<< eIf
        (toExp (le64 k .<. pe64 tk))
        (eBody [toExp dest_ix])
        (eBody [eSubExp (intConst Int64 (-1))])
  pure (elem_val, elem_dest)
scatterFun _ _ _ =
  -- The message names this module's file rather than Shared.hs, because this
  -- helper is defined here and not in Futhark.Optimise.TileLoops.Shared.
  error "Function scatterFun in BlkRegTiling.hs: 2nd arg should be an array with 2 elements!"
-- | Entry point for block/register tiling of a statement.
--
-- 'mmBlkRegTilingAcc' is tried first; if it yields 'Nothing' the statement
-- is handed to 'mmBlkRegTilingNrm' instead. A successful result from the
-- first attempt is returned unchanged.
mmBlkRegTiling :: Env -> Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
mmBlkRegTiling env stm =
  mmBlkRegTilingAcc env stm
    >>= maybe (mmBlkRegTilingNrm env stm) (pure . Just)
mmBlkRegTilingAcc :: Env -> Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
mmBlkRegTilingAcc :: Env -> Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
mmBlkRegTilingAcc Env
env (Let Pat (LetDec GPU)
pat StmAux (ExpDec GPU)
aux (Op (SegOp (SegMap SegThread {} SegSpace
seg_space [TypeBase Shape NoUniqueness]
ts KernelBody GPU
old_kbody))))
| KernelBody () Stms GPU
kstms [Returns ResultManifest
ResultMaySimplify Certs
cs (Var VName
res_nm)] <- KernelBody GPU
old_kbody,
Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty,
[TypeBase Shape NoUniqueness
res_tp] <- [TypeBase Shape NoUniqueness]
ts,
TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
isAcc TypeBase Shape NoUniqueness
res_tp,
(VName
gtid_x, SubExp
width_B) : (VName
gtid_y, SubExp
height_A) : [(VName, SubExp)]
rem_outer_dims_rev <-
[(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
seg_space,
[(VName, SubExp)]
rem_outer_dims <- [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse [(VName, SubExp)]
rem_outer_dims_rev,
Just
( Stms GPU
code2',
(Stm GPU
load_A, VName
inp_A, PrimType
map_t1, Stm GPU
load_B, VName
inp_B, PrimType
map_t2),
SubExp
common_dim,
[Int]
var_dims,
(Lambda GPU
map_lam, Lambda GPU
red_lam, SubExp
red_ne, VName
redomap_orig_res, PrimType
red_t)
) <-
SegSpace
-> Stms GPU
-> Maybe
(Stms GPU, (Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
SubExp, [Int], (Lambda GPU, Lambda GPU, SubExp, VName, PrimType))
matchesBlkRegTile SegSpace
seg_space Stms GPU
kstms,
VName -> Stms GPU -> VName -> Bool
forall {rep}. VName -> Stms rep -> VName -> Bool
checkAccumulatesRedomapRes VName
res_nm Stms GPU
code2' VName
redomap_orig_res = do
let is_B_coal :: Bool
is_B_coal = Env -> VName -> Stm GPU -> Bool
isInnerCoal Env
env VName
inp_B Stm GPU
load_B
(Stm GPU
new_kernel, Stms GPU
host_stms) <- Builder GPU (Stm GPU)
-> ReaderT (Scope GPU) (State VNameSource) (Stm GPU, Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder (Builder GPU (Stm GPU)
-> ReaderT (Scope GPU) (State VNameSource) (Stm GPU, Stms GPU))
-> Builder GPU (Stm GPU)
-> ReaderT (Scope GPU) (State VNameSource) (Stm GPU, Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
(SubExp
rx, SubExp
ry, SubExp
tx, SubExp
ty, SubExp
tk, SubExp
tk_div_tx, SubExp
tk_div_ty, SubExp
tx_rx, SubExp
ty_ry, SubExp
a_loc_sz, SubExp
b_loc_sz) <-
SubExp
-> SubExp
-> SubExp
-> Bool
-> Builder
GPU
(SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp,
SubExp, SubExp, SubExp)
mkTileMemSizes SubExp
height_A SubExp
width_B SubExp
common_dim Bool
is_B_coal
SubExp
rk <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"rk" (Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp)
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
8
SubExp
tk_rk <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"tk_rk" (Exp GPU -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
tk TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
rk)
SubExp
gridDim_t <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"gridDim_t" (Exp GPU -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv SubExp
common_dim SubExp
tk_rk
SubExp
gridDim_y <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"gridDim_y" (Exp GPU -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv SubExp
height_A SubExp
ty_ry
SubExp
gridDim_x <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"gridDim_x" (Exp GPU -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv SubExp
width_B SubExp
tx_rx
let gridxyt_pexp :: TPrimExp Int64 VName
gridxyt_pexp = SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_x TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_t
grid_pexp :: TPrimExp Int64 VName
grid_pexp =
(TPrimExp Int64 VName -> SubExp -> TPrimExp Int64 VName)
-> TPrimExp Int64 VName -> [SubExp] -> TPrimExp Int64 VName
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\TPrimExp Int64 VName
x SubExp
d -> SubExp -> TPrimExp Int64 VName
pe64 SubExp
d TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
x) TPrimExp Int64 VName
gridxyt_pexp ([SubExp] -> TPrimExp Int64 VName)
-> [SubExp] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$
((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(VName, SubExp)]
rem_outer_dims_rev
(SubExp
grid_size, SubExp
tblock_size, SegLevel
segthd_lvl) <- SubExp
-> SubExp
-> TPrimExp Int64 VName
-> Builder GPU (SubExp, SubExp, SegLevel)
mkNewSegthdLvl SubExp
tx SubExp
ty TPrimExp Int64 VName
grid_pexp
(VName
gid_x, VName
gid_y, VName
gid_flat) <- Builder GPU (VName, VName, VName)
mkGidsXYF
VName
gid_t <- [Char] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_t"
([KernelResult]
ret_seggroup, Stms GPU
stms_seggroup) <- Builder GPU [KernelResult]
-> BuilderT GPU (State VNameSource) ([KernelResult], Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder (Builder GPU [KernelResult]
-> BuilderT GPU (State VNameSource) ([KernelResult], Stms GPU))
-> Builder GPU [KernelResult]
-> BuilderT GPU (State VNameSource) ([KernelResult], Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
VName
iii <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"iii" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
ty_ry)
VName
jjj <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"jjj" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_x TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx_rx)
VName
ttt <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"ttt" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_t TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tk_rk)
(VName
cssss, VName
a_loc_init, VName
b_loc_init) <-
(SubExp, SubExp, SubExp, SubExp, SubExp, SubExp)
-> (PrimType, PrimType, PrimType)
-> SegLevel
-> SubExp
-> Builder GPU (VName, VName, VName)
initRegShmem
(SubExp
rx, SubExp
tx, SubExp
ry, SubExp
ty, SubExp
a_loc_sz, SubExp
b_loc_sz)
(PrimType
map_t1, PrimType
map_t2, PrimType
red_t)
SegLevel
segthd_lvl
SubExp
red_ne
SubExp
elems_on_t <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"elems_on_t" (Exp GPU -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
common_dim TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ttt)
SubExp
tiles_on_t <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"tiles_on_t" (Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp)
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SQuot IntType
Int64 Safety
Unsafe) SubExp
elems_on_t SubExp
tk
VName
full_tiles <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"full_tiles" (Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName)
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> BinOp
SMin IntType
Int64) SubExp
rk SubExp
tiles_on_t
let ct_arg :: ((SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp),
SegLevel, [Int], (VName, SubExp, VName, SubExp, SubExp),
(VName, VName),
(Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
(Lambda GPU, Lambda GPU))
ct_arg =
( (SubExp
rx, SubExp
ry, SubExp
tx, SubExp
ty, SubExp
tk, SubExp
tk_div_tx, SubExp
tk_div_ty, SubExp
tx_rx),
SegLevel
segthd_lvl,
[Int]
var_dims,
(VName
gtid_x, SubExp
width_B, VName
gtid_y, SubExp
height_A, SubExp
common_dim),
(VName
iii, VName
jjj),
(Stm GPU
load_A, VName
inp_A, PrimType
map_t1, Stm GPU
load_B, VName
inp_B, PrimType
map_t2),
(Lambda GPU
map_lam, Lambda GPU
red_lam)
)
[VName]
prologue_res_list <-
SubExp
-> [VName]
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName]
forLoop' (VName -> SubExp
Var VName
full_tiles) [VName
cssss, VName
a_loc_init, VName
b_loc_init] ((VName -> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName])
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$
\VName
kk0 [VName
thd_res_merge, VName
a_loc_merge, VName
b_loc_merge] -> do
VName
off_t <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"off_t" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
rk TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_t TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
kk0)
[VName]
process_full_tiles <-
Env
-> ((SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp,
SubExp),
SegLevel, [Int], (VName, SubExp, VName, SubExp, SubExp),
(VName, VName),
(Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
(Lambda GPU, Lambda GPU))
-> VName
-> (VName, VName, VName)
-> Bool
-> Builder GPU [VName]
kkLoopBody Env
env ((SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp),
SegLevel, [Int], (VName, SubExp, VName, SubExp, SubExp),
(VName, VName),
(Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
(Lambda GPU, Lambda GPU))
ct_arg VName
off_t (VName
thd_res_merge, VName
a_loc_merge, VName
b_loc_merge) Bool
False
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM ([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource)))))
-> [SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
process_full_tiles
let VName
prologue_res : VName
a_loc_reuse : VName
b_loc_reuse : [VName]
_ = [VName]
prologue_res_list
[VName]
redomap_res_lst <-
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU [VName]
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [VName]
letTupExp [Char]
"redomap_res_if"
(Exp GPU -> Builder GPU [VName])
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> Builder GPU [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
( TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource)))))
-> TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$
VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
full_tiles
TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. SubExp -> TPrimExp Int64 VName
pe64 SubExp
rk
TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. SubExp -> TPrimExp Int64 VName
pe64 SubExp
common_dim
TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. (SubExp -> TPrimExp Int64 VName
pe64 SubExp
tk TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
full_tiles TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ttt)
)
([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM ([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource)))))
-> [SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
prologue_res_list)
( do
VName
off_t <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"off_t" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
rk TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_t TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
full_tiles)
[VName]
process_sprs_tile <-
Env
-> ((SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp,
SubExp),
SegLevel, [Int], (VName, SubExp, VName, SubExp, SubExp),
(VName, VName),
(Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
(Lambda GPU, Lambda GPU))
-> VName
-> (VName, VName, VName)
-> Bool
-> Builder GPU [VName]
kkLoopBody Env
env ((SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp),
SegLevel, [Int], (VName, SubExp, VName, SubExp, SubExp),
(VName, VName),
(Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
(Lambda GPU, Lambda GPU))
ct_arg VName
off_t (VName
prologue_res, VName
a_loc_reuse, VName
b_loc_reuse) Bool
True
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM ([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource)))))
-> [SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
process_sprs_tile
)
let VName
redomap_res : [VName]
_ = [VName]
redomap_res_lst
SegLevel
-> (VName, VName)
-> (VName, TypeBase Shape NoUniqueness)
-> (SubExp, SubExp, SubExp, SubExp)
-> (VName, VName)
-> (VName, VName)
-> (SubExp, SubExp, [(VName, SubExp)])
-> Stms GPU
-> Builder GPU [KernelResult]
mkEpilogueAccRes
SegLevel
segthd_lvl
(VName
redomap_orig_res, VName
redomap_res)
(VName
res_nm, TypeBase Shape NoUniqueness
res_tp)
(SubExp
ty, SubExp
tx, SubExp
ry, SubExp
rx)
(VName
iii, VName
jjj)
(VName
gtid_y, VName
gtid_x)
(SubExp
height_A, SubExp
width_B, [(VName, SubExp)]
rem_outer_dims)
Stms GPU
code2'
let grid :: KernelGrid
grid = Count NumBlocks SubExp -> Count BlockSize SubExp -> KernelGrid
KernelGrid (SubExp -> Count NumBlocks SubExp
forall {k} (u :: k) e. e -> Count u e
Count SubExp
grid_size) (SubExp -> Count BlockSize SubExp
forall {k} (u :: k) e. e -> Count u e
Count SubExp
tblock_size)
level' :: SegLevel
level' = SegVirt -> Maybe KernelGrid -> SegLevel
SegBlock SegVirt
SegNoVirt (KernelGrid -> Maybe KernelGrid
forall a. a -> Maybe a
Just KernelGrid
grid)
space' :: SegSpace
space' = VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat ([(VName, SubExp)]
rem_outer_dims [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
gid_t, SubExp
gridDim_t), (VName
gid_y, SubExp
gridDim_y), (VName
gid_x, SubExp
gridDim_x)])
kbody' :: KernelBody GPU
kbody' = BodyDec GPU -> Stms GPU -> [KernelResult] -> KernelBody GPU
forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () Stms GPU
stms_seggroup [KernelResult]
ret_seggroup
Stm GPU -> Builder GPU (Stm GPU)
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stm GPU -> Builder GPU (Stm GPU))
-> Stm GPU -> Builder GPU (Stm GPU)
forall a b. (a -> b) -> a -> b
$ Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec GPU)
pat StmAux (ExpDec GPU)
aux (Exp GPU -> Stm GPU) -> Exp GPU -> Stm GPU
forall a b. (a -> b) -> a -> b
$ Op GPU -> Exp GPU
forall rep. Op rep -> Exp rep
Op (Op GPU -> Exp GPU) -> Op GPU -> Exp GPU
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPU -> HostOp SOAC GPU
forall (op :: * -> *) rep. SegOp SegLevel rep -> HostOp op rep
SegOp (SegOp SegLevel GPU -> HostOp SOAC GPU)
-> SegOp SegLevel GPU -> HostOp SOAC GPU
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [TypeBase Shape NoUniqueness]
-> KernelBody GPU
-> SegOp SegLevel GPU
forall lvl rep.
lvl
-> SegSpace
-> [TypeBase Shape NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegMap SegLevel
level' SegSpace
space' [TypeBase Shape NoUniqueness]
ts KernelBody GPU
kbody'
Maybe (Stms GPU, Stm GPU) -> TileM (Maybe (Stms GPU, Stm GPU))
forall a. a -> ReaderT (Scope GPU) (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe (Stms GPU, Stm GPU) -> TileM (Maybe (Stms GPU, Stm GPU)))
-> Maybe (Stms GPU, Stm GPU) -> TileM (Maybe (Stms GPU, Stm GPU))
forall a b. (a -> b) -> a -> b
$ (Stms GPU, Stm GPU) -> Maybe (Stms GPU, Stm GPU)
forall a. a -> Maybe a
Just (Stms GPU
host_stms, Stm GPU
new_kernel)
where
-- Does the given type denote an accumulator whose identifying name
-- (the first field of 'Acc') equals @acc_sglton@?  Every non-'Acc'
-- type yields 'False'.
sameAccType :: VName -> TypeBase shape u -> Bool
sameAccType acc_sglton tp =
  case tp of
    Acc sglton _ _ _ -> acc_sglton == sglton
    _ -> False
-- Given the (accumulator) result type, find the one free variable of
-- @old_kbody@ whose type is an accumulator with the same identifying
-- name.  Types of the free variables are looked up with the scope of
-- @seg_space@ added.
--
-- Calls 'error' when the argument is not an 'Acc' with exactly one
-- element type, or when the number of matching free variables is not
-- exactly one.
getAccumFV :: TypeBase Shape NoUniqueness -> Builder GPU VName
getAccumFV tp =
  case tp of
    Acc singleton _shp [_eltp] _ -> do
      let free_vars = namesToList (freeIn old_kbody)
      free_tps <-
        localScope (scopeOfSegSpace seg_space) $
          mapM lookupType free_vars
      -- Keep only the free variables typed as the same accumulator.
      let matching =
            [ fv
              | (fv, fv_tp) <- zip free_vars free_tps,
                sameAccType singleton fv_tp
            ]
      case matching of
        [acc_0] -> pure acc_0
        _ -> error "Impossible case reached when treating accumulators!"
    _ ->
      error ("Should be an accumulator type at this point, given: " ++ prettyString tp)
-- Checks whether @acc_code@ contains a statement of the form
--
-- > let res_nm = UpdateAcc _ _ _ [Var redomap_orig_res]
--
-- i.e. a single-pattern 'UpdateAcc' binding @res_nm@ whose only
-- accumulated value is exactly the variable @redomap_orig_res@.
--
-- The answer is 'True' as soon as any statement qualifies, so the
-- order in which the statements are inspected does not affect the
-- result; 'any' expresses this directly (and short-circuits), in place
-- of a lazy left fold over the reversed statement list.
checkAccumulatesRedomapRes :: VName -> Stms rep -> VName -> Bool
checkAccumulatesRedomapRes res_nm acc_code redomap_orig_res =
  any accumulatesRedomapRes (stmsToList acc_code)
  where
    -- An 'UpdateAcc' with several values, or one binding a name other
    -- than @res_nm@, does not qualify.
    accumulatesRedomapRes (Let (Pat [pat_el]) _aux (BasicOp (UpdateAcc _ _acc_nm _ind [v]))) =
      patElemName pat_el == res_nm && v == Var redomap_orig_res
    accumulatesRedomapRes _ = False
-- Builds the kernel result for the case where the result is an
-- accumulator.
--
-- The initial accumulator is obtained with 'getAccumFV' from @res_tp@.
-- A 'segMap2D' (named "rssss") over @(ty, tx)@ then runs, per thread,
-- two nested 'forLoop's of @ry@ and @rx@ iterations that thread the
-- accumulator through.  Each inner iteration:
--
--   * calls 'prereqAddCode2' with the global ids, the loop indices
--     and the per-thread redomap result obtained from
--     'getThdRedomapRes';
--
--   * substitutes the current accumulator for the initial one in
--     @code2'@;
--
--   * if @gtid_y < height_A && gtid_x < width_B@, adds the substituted
--     statements and yields @res_nm@; otherwise yields the current
--     accumulator unchanged.
--
-- The first result of the 'segMap2D' is returned as the sole
-- 'KernelResult'.  The remaining outer dimensions are not used here.
mkEpilogueAccRes ::
  SegLevel ->
  (VName, VName) ->
  (VName, TypeBase Shape NoUniqueness) ->
  (SubExp, SubExp, SubExp, SubExp) ->
  (VName, VName) ->
  (VName, VName) ->
  (SubExp, SubExp, [(VName, SubExp)]) ->
  Stms GPU ->
  Builder GPU [KernelResult]
mkEpilogueAccRes
  segthd_lvl
  (redomap_orig_res, redomap_res)
  (res_nm, res_tp)
  (ty, tx, ry, rx)
  (iii, jjj)
  (gtid_y, gtid_x)
  (height_A, width_B, _rem_outer_dims)
  code2' = do
    acc_init <- getAccumFV res_tp
    let -- Guard deciding whether the accumulation code is emitted.
        in_bounds =
          le64 gtid_y .<. pe64 height_A
            .&&. le64 gtid_x .<. pe64 width_B
        -- One step of the inner loop: yields the accumulator to carry
        -- into the next iteration.
        updateElem css ii jj i j acc_cur = do
          prereqAddCode2 (gtid_x, gtid_y) (ii, i, jj, j) (css, redomap_orig_res)
          let code2_subs =
                substituteNames (M.singleton acc_init acc_cur) code2'
          letSubExp "res_elem"
            =<< eIf
              (toExp in_bounds)
              (addStms code2_subs >> resultBodyM [Var res_nm])
              (resultBodyM [Var acc_cur])
    per_thread_accs <-
      segMap2D "rssss" segthd_lvl ResultMaySimplify (ty, tx) $
        \(ltid_y, ltid_x) -> do
          (css, ii, jj) <-
            getThdRedomapRes (rx, ry) (ltid_x, ltid_y) (iii, jjj, redomap_res)
          rss <- forLoop ry [acc_init] $ \i [row_acc] -> do
            rss' <- forLoop rx [row_acc] $ \j [elem_acc] -> do
              res_el <- updateElem css ii jj i j elem_acc
              resultBodyM [res_el]
            resultBodyM [Var rss']
          pure [varRes rss]
    let epilogue_res_acc : _ = per_thread_accs
    pure [Returns ResultMaySimplify (Certs []) (Var epilogue_res_acc)]
mmBlkRegTilingAcc Env
_ Stm GPU
_ = Maybe (Stms GPU, Stm GPU) -> TileM (Maybe (Stms GPU, Stm GPU))
forall a. a -> ReaderT (Scope GPU) (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (Stms GPU, Stm GPU)
forall a. Maybe a
Nothing
mmBlkRegTilingNrm :: Env -> Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
mmBlkRegTilingNrm :: Env -> Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
mmBlkRegTilingNrm Env
env (Let Pat (LetDec GPU)
pat StmAux (ExpDec GPU)
aux (Op (SegOp (SegMap SegThread {} SegSpace
seg_space [TypeBase Shape NoUniqueness]
ts KernelBody GPU
old_kbody))))
| KernelBody () Stms GPU
kstms [Returns ResultManifest
ResultMaySimplify Certs
cs (Var VName
res_nm)] <- KernelBody GPU
old_kbody,
Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty,
[TypeBase Shape NoUniqueness
res_tp] <- [TypeBase Shape NoUniqueness]
ts,
TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType TypeBase Shape NoUniqueness
res_tp,
(VName
gtid_x, SubExp
width_B) : (VName
gtid_y, SubExp
height_A) : [(VName, SubExp)]
rem_outer_dims_rev <-
[(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
seg_space,
[(VName, SubExp)]
rem_outer_dims <- [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse [(VName, SubExp)]
rem_outer_dims_rev,
Just
( Stms GPU
code2',
(Stm GPU
load_A, VName
inp_A, PrimType
map_t1, Stm GPU
load_B, VName
inp_B, PrimType
map_t2),
SubExp
common_dim,
[Int]
var_dims,
(Lambda GPU
map_lam, Lambda GPU
red_lam, SubExp
red_ne, VName
redomap_orig_res, PrimType
red_t)
) <-
SegSpace
-> Stms GPU
-> Maybe
(Stms GPU, (Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
SubExp, [Int], (Lambda GPU, Lambda GPU, SubExp, VName, PrimType))
matchesBlkRegTile SegSpace
seg_space Stms GPU
kstms = do
let is_B_coal :: Bool
is_B_coal = Env -> VName -> Stm GPU -> Bool
isInnerCoal Env
env VName
inp_B Stm GPU
load_B
(Stm GPU
new_kernel, Stms GPU
host_stms) <- Builder GPU (Stm GPU)
-> ReaderT (Scope GPU) (State VNameSource) (Stm GPU, Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder (Builder GPU (Stm GPU)
-> ReaderT (Scope GPU) (State VNameSource) (Stm GPU, Stms GPU))
-> Builder GPU (Stm GPU)
-> ReaderT (Scope GPU) (State VNameSource) (Stm GPU, Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
(SubExp
rx, SubExp
ry, SubExp
tx, SubExp
ty, SubExp
tk, SubExp
tk_div_tx, SubExp
tk_div_ty, SubExp
tx_rx, SubExp
ty_ry, SubExp
a_loc_sz, SubExp
b_loc_sz) <-
SubExp
-> SubExp
-> SubExp
-> Bool
-> Builder
GPU
(SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp,
SubExp, SubExp, SubExp)
mkTileMemSizes SubExp
height_A SubExp
width_B SubExp
common_dim Bool
is_B_coal
SubExp
gridDim_x <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"gridDim_x" (Exp GPU -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv SubExp
width_B SubExp
tx_rx
SubExp
gridDim_y <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"gridDim_y" (Exp GPU -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv SubExp
height_A SubExp
ty_ry
let gridxy_pexp :: TPrimExp Int64 VName
gridxy_pexp = SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_x
let grid_pexp :: TPrimExp Int64 VName
grid_pexp =
(TPrimExp Int64 VName -> SubExp -> TPrimExp Int64 VName)
-> TPrimExp Int64 VName -> [SubExp] -> TPrimExp Int64 VName
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\TPrimExp Int64 VName
x SubExp
d -> SubExp -> TPrimExp Int64 VName
pe64 SubExp
d TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
x) TPrimExp Int64 VName
gridxy_pexp ([SubExp] -> TPrimExp Int64 VName)
-> [SubExp] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$
((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(VName, SubExp)]
rem_outer_dims_rev
(SubExp
grid_size, SubExp
tblock_size, SegLevel
segthd_lvl) <- SubExp
-> SubExp
-> TPrimExp Int64 VName
-> Builder GPU (SubExp, SubExp, SegLevel)
mkNewSegthdLvl SubExp
tx SubExp
ty TPrimExp Int64 VName
grid_pexp
(VName
gid_x, VName
gid_y, VName
gid_flat) <- Builder GPU (VName, VName, VName)
mkGidsXYF
([KernelResult]
ret_seggroup, Stms GPU
stms_seggroup) <- Builder GPU [KernelResult]
-> BuilderT GPU (State VNameSource) ([KernelResult], Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder (Builder GPU [KernelResult]
-> BuilderT GPU (State VNameSource) ([KernelResult], Stms GPU))
-> Builder GPU [KernelResult]
-> BuilderT GPU (State VNameSource) ([KernelResult], Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
VName
iii <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"iii" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
ty_ry)
VName
jjj <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"jjj" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_x TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx_rx)
(VName
cssss, VName
a_loc_init, VName
b_loc_init) <-
(SubExp, SubExp, SubExp, SubExp, SubExp, SubExp)
-> (PrimType, PrimType, PrimType)
-> SegLevel
-> SubExp
-> Builder GPU (VName, VName, VName)
initRegShmem
(SubExp
rx, SubExp
tx, SubExp
ry, SubExp
ty, SubExp
a_loc_sz, SubExp
b_loc_sz)
(PrimType
map_t1, PrimType
map_t2, PrimType
red_t)
SegLevel
segthd_lvl
SubExp
red_ne
VName
full_tiles <-
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"full_tiles" (Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName)
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$
BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SQuot IntType
Int64 Safety
Unsafe) SubExp
common_dim SubExp
tk
let ct_arg :: ((SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp),
SegLevel, [Int], (VName, SubExp, VName, SubExp, SubExp),
(VName, VName),
(Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
(Lambda GPU, Lambda GPU))
ct_arg =
( (SubExp
rx, SubExp
ry, SubExp
tx, SubExp
ty, SubExp
tk, SubExp
tk_div_tx, SubExp
tk_div_ty, SubExp
tx_rx),
SegLevel
segthd_lvl,
[Int]
var_dims,
(VName
gtid_x, SubExp
width_B, VName
gtid_y, SubExp
height_A, SubExp
common_dim),
(VName
iii, VName
jjj),
(Stm GPU
load_A, VName
inp_A, PrimType
map_t1, Stm GPU
load_B, VName
inp_B, PrimType
map_t2),
(Lambda GPU
map_lam, Lambda GPU
red_lam)
)
[VName]
prologue_res_list <-
SubExp
-> [VName]
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName]
forLoop' (VName -> SubExp
Var VName
full_tiles) [VName
cssss, VName
a_loc_init, VName
b_loc_init] ((VName -> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName])
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$
\VName
kk0 [VName
thd_res_merge, VName
a_loc_merge, VName
b_loc_merge] -> do
[VName]
process_full_tiles <-
Env
-> ((SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp,
SubExp),
SegLevel, [Int], (VName, SubExp, VName, SubExp, SubExp),
(VName, VName),
(Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
(Lambda GPU, Lambda GPU))
-> VName
-> (VName, VName, VName)
-> Bool
-> Builder GPU [VName]
kkLoopBody Env
env ((SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp),
SegLevel, [Int], (VName, SubExp, VName, SubExp, SubExp),
(VName, VName),
(Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
(Lambda GPU, Lambda GPU))
ct_arg VName
kk0 (VName
thd_res_merge, VName
a_loc_merge, VName
b_loc_merge) Bool
False
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM ([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource)))))
-> [SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
process_full_tiles
let VName
prologue_res : VName
a_loc_reuse : VName
b_loc_reuse : [VName]
_ = [VName]
prologue_res_list
[VName]
epilogue_res_list <- Env
-> ((SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp,
SubExp),
SegLevel, [Int], (VName, SubExp, VName, SubExp, SubExp),
(VName, VName),
(Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
(Lambda GPU, Lambda GPU))
-> VName
-> (VName, VName, VName)
-> Bool
-> Builder GPU [VName]
kkLoopBody Env
env ((SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp),
SegLevel, [Int], (VName, SubExp, VName, SubExp, SubExp),
(VName, VName),
(Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
(Lambda GPU, Lambda GPU))
ct_arg VName
full_tiles (VName
prologue_res, VName
a_loc_reuse, VName
b_loc_reuse) Bool
True
let VName
redomap_res : [VName]
_ = [VName]
epilogue_res_list
SegLevel
-> (VName, VName)
-> (VName, TypeBase Shape NoUniqueness)
-> (SubExp, SubExp, SubExp, SubExp)
-> (VName, VName)
-> (VName, VName)
-> (SubExp, SubExp, [(VName, SubExp)])
-> Stms GPU
-> Builder GPU [KernelResult]
forall {a}.
SegLevel
-> (VName, VName)
-> (VName, TypeBase Shape NoUniqueness)
-> (SubExp, SubExp, SubExp, SubExp)
-> (VName, VName)
-> (VName, VName)
-> (SubExp, SubExp, [(a, SubExp)])
-> Stms GPU
-> Builder GPU [KernelResult]
mkEpiloguePrimRes
SegLevel
segthd_lvl
(VName
redomap_orig_res, VName
redomap_res)
(VName
res_nm, TypeBase Shape NoUniqueness
res_tp)
(SubExp
ty, SubExp
tx, SubExp
ry, SubExp
rx)
(VName
iii, VName
jjj)
(VName
gtid_y, VName
gtid_x)
(SubExp
height_A, SubExp
width_B, [(VName, SubExp)]
rem_outer_dims)
Stms GPU
code2'
let grid :: KernelGrid
grid = Count NumBlocks SubExp -> Count BlockSize SubExp -> KernelGrid
KernelGrid (SubExp -> Count NumBlocks SubExp
forall {k} (u :: k) e. e -> Count u e
Count SubExp
grid_size) (SubExp -> Count BlockSize SubExp
forall {k} (u :: k) e. e -> Count u e
Count SubExp
tblock_size)
level' :: SegLevel
level' = SegVirt -> Maybe KernelGrid -> SegLevel
SegBlock SegVirt
SegNoVirt (KernelGrid -> Maybe KernelGrid
forall a. a -> Maybe a
Just KernelGrid
grid)
space' :: SegSpace
space' = VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat ([(VName, SubExp)]
rem_outer_dims [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
gid_y, SubExp
gridDim_y), (VName
gid_x, SubExp
gridDim_x)])
kbody' :: KernelBody GPU
kbody' = BodyDec GPU -> Stms GPU -> [KernelResult] -> KernelBody GPU
forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () Stms GPU
stms_seggroup [KernelResult]
ret_seggroup
Stm GPU -> Builder GPU (Stm GPU)
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stm GPU -> Builder GPU (Stm GPU))
-> Stm GPU -> Builder GPU (Stm GPU)
forall a b. (a -> b) -> a -> b
$ Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec GPU)
pat StmAux (ExpDec GPU)
aux (Exp GPU -> Stm GPU) -> Exp GPU -> Stm GPU
forall a b. (a -> b) -> a -> b
$ Op GPU -> Exp GPU
forall rep. Op rep -> Exp rep
Op (Op GPU -> Exp GPU) -> Op GPU -> Exp GPU
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPU -> HostOp SOAC GPU
forall (op :: * -> *) rep. SegOp SegLevel rep -> HostOp op rep
SegOp (SegOp SegLevel GPU -> HostOp SOAC GPU)
-> SegOp SegLevel GPU -> HostOp SOAC GPU
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [TypeBase Shape NoUniqueness]
-> KernelBody GPU
-> SegOp SegLevel GPU
forall lvl rep.
lvl
-> SegSpace
-> [TypeBase Shape NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegMap SegLevel
level' SegSpace
space' [TypeBase Shape NoUniqueness]
ts KernelBody GPU
kbody'
Maybe (Stms GPU, Stm GPU) -> TileM (Maybe (Stms GPU, Stm GPU))
forall a. a -> ReaderT (Scope GPU) (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe (Stms GPU, Stm GPU) -> TileM (Maybe (Stms GPU, Stm GPU)))
-> Maybe (Stms GPU, Stm GPU) -> TileM (Maybe (Stms GPU, Stm GPU))
forall a b. (a -> b) -> a -> b
$ (Stms GPU, Stm GPU) -> Maybe (Stms GPU, Stm GPU)
forall a. a -> Maybe a
Just (Stms GPU
host_stms, Stm GPU
new_kernel)
where
mkEpiloguePrimRes :: SegLevel
-> (VName, VName)
-> (VName, TypeBase Shape NoUniqueness)
-> (SubExp, SubExp, SubExp, SubExp)
-> (VName, VName)
-> (VName, VName)
-> (SubExp, SubExp, [(a, SubExp)])
-> Stms GPU
-> Builder GPU [KernelResult]
mkEpiloguePrimRes
SegLevel
segthd_lvl
(VName
redomap_orig_res, VName
redomap_res)
(VName
res_nm, TypeBase Shape NoUniqueness
res_tp)
(SubExp
ty, SubExp
tx, SubExp
ry, SubExp
rx)
(VName
iii, VName
jjj)
(VName
gtid_y, VName
gtid_x)
(SubExp
height_A, SubExp
width_B, [(a, SubExp)]
rem_outer_dims)
Stms GPU
code2' = do
VName
epilogue_res <-
if VName
redomap_orig_res VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
res_nm
then VName -> BuilderT GPU (State VNameSource) VName
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
redomap_res
else do
[VName]
rssss_list <- [Char]
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName]
segMap2D [Char]
"rssss" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
ty, SubExp
tx) (((VName, VName) -> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName])
-> ((VName, VName) -> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
ltid_y, VName
ltid_x) -> do
VName
rss_init <- [Char]
-> PrimType -> [SubExp] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> PrimType -> [SubExp] -> m VName
scratch [Char]
"rss_init" (TypeBase Shape NoUniqueness -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType TypeBase Shape NoUniqueness
res_tp) [SubExp
ry, SubExp
rx]
(VName
css, VName
ii, VName
jj) <- (SubExp, SubExp)
-> (VName, VName)
-> (VName, VName, VName)
-> Builder GPU (VName, VName, VName)
getThdRedomapRes (SubExp
rx, SubExp
ry) (VName
ltid_x, VName
ltid_y) (VName
iii, VName
jjj, VName
redomap_res)
VName
rss <- SubExp
-> [VName]
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forLoop SubExp
ry [VName
rss_init] ((VName -> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName)
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
i [VName
rss_merge] -> do
VName
rss' <- SubExp
-> [VName]
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forLoop SubExp
rx [VName
rss_merge] ((VName -> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName)
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
j [VName
rss_merge'] -> do
(VName, VName)
-> (VName, VName, VName, VName)
-> (VName, VName)
-> BuilderT GPU (State VNameSource) ()
prereqAddCode2 (VName
gtid_x, VName
gtid_y) (VName
ii, VName
i, VName
jj, VName
j) (VName
css, VName
redomap_orig_res)
SubExp
res_el <-
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"res_elem"
(Exp GPU -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
( TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource)))))
-> TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$
VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_y
TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
height_A
TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_x
TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
width_B
)
( do
Stms (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms Stms (Rep (BuilderT GPU (State VNameSource)))
Stms GPU
code2'
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [VName -> SubExp
Var VName
res_nm]
)
([BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [TypeBase Shape NoUniqueness
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
TypeBase Shape NoUniqueness -> m (Exp (Rep m))
eBlank TypeBase Shape NoUniqueness
res_tp])
VName
rss'' <- [Char]
-> VName
-> [VName]
-> SubExp
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> SubExp -> m VName
update [Char]
"rss" VName
rss_merge' [VName
i, VName
j] SubExp
res_el
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [VName -> SubExp
Var VName
rss'']
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [VName -> SubExp
Var VName
rss']
[SubExpRes] -> BuilderT GPU (State VNameSource) [SubExpRes]
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName -> SubExpRes
varRes VName
rss]
let VName
rssss : [VName]
_ = [VName]
rssss_list
VName -> BuilderT GPU (State VNameSource) VName
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
rssss
let regtile_ret_dims :: [(SubExp, SubExp, SubExp)]
regtile_ret_dims =
((a, SubExp) -> (SubExp, SubExp, SubExp))
-> [(a, SubExp)] -> [(SubExp, SubExp, SubExp)]
forall a b. (a -> b) -> [a] -> [b]
map (\(a
_, SubExp
sz) -> (SubExp
sz, SubExp
se1, SubExp
se1)) [(a, SubExp)]
rem_outer_dims
[(SubExp, SubExp, SubExp)]
-> [(SubExp, SubExp, SubExp)] -> [(SubExp, SubExp, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(SubExp
height_A, SubExp
ty, SubExp
ry), (SubExp
width_B, SubExp
tx, SubExp
rx)]
VName
epilogue_res' <-
if [(a, SubExp)] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(a, SubExp)]
rem_outer_dims
then VName -> BuilderT GPU (State VNameSource) VName
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
epilogue_res
else do
TypeBase Shape NoUniqueness
epilogue_t <- VName
-> BuilderT GPU (State VNameSource) (TypeBase Shape NoUniqueness)
forall rep (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
epilogue_res
let ([SubExp]
block_dims, [SubExp]
rest_dims) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
2 ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$ TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims TypeBase Shape NoUniqueness
epilogue_t
ones :: [SubExp]
ones = ((a, SubExp) -> SubExp) -> [(a, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> (a, SubExp) -> SubExp
forall a b. a -> b -> a
const (SubExp -> (a, SubExp) -> SubExp)
-> SubExp -> (a, SubExp) -> SubExp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) [(a, SubExp)]
rem_outer_dims
new_shape :: Shape
new_shape = [SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$ [[SubExp]] -> [SubExp]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[SubExp]
ones, [SubExp]
block_dims, [SubExp]
ones, [SubExp]
rest_dims]
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"res_reshaped" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> (BasicOp -> Exp GPU)
-> BasicOp
-> BuilderT GPU (State VNameSource) VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp GPU
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> BuilderT GPU (State VNameSource) VName)
-> BasicOp -> BuilderT GPU (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$
ReshapeKind -> Shape -> VName -> BasicOp
Reshape ReshapeKind
ReshapeArbitrary Shape
new_shape VName
epilogue_res
[KernelResult] -> Builder GPU [KernelResult]
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [Certs -> [(SubExp, SubExp, SubExp)] -> VName -> KernelResult
RegTileReturns Certs
forall a. Monoid a => a
mempty [(SubExp, SubExp, SubExp)]
regtile_ret_dims VName
epilogue_res']
mmBlkRegTilingNrm Env
_ Stm GPU
_ = Maybe (Stms GPU, Stm GPU) -> TileM (Maybe (Stms GPU, Stm GPU))
forall a. a -> ReaderT (Scope GPU) (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (Stms GPU, Stm GPU)
forall a. Maybe a
Nothing
-- | Decide whether the kernel body @kstms@ (living in @seg_space@) has the
-- shape this module knows how to block/register tile.
--
-- On success the result is
--
-- > ( code2'
-- > , (load_A, inp_A, map_t1, load_B, inp_B, map_t2)
-- > , common_dim
-- > , var_dims
-- > , (map_lam, red_lam, red_ne, redomap_orig_res, red_t)
-- > )
--
-- where @code2'@ is the part of the leading code kept by
-- 'processIndirections' followed by the code after the redomap, the
-- @load_*@ statements are the ones recorded for the two redomap input
-- arrays, and @red_t@ is the element type of the first reduce-lambda
-- parameter.  'Nothing' is returned as soon as any requirement fails.
matchesBlkRegTile ::
  SegSpace ->
  Stms GPU ->
  Maybe
    ( Stms GPU,
      (Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
      SubExp,
      [Int],
      (Lambda GPU, Lambda GPU, SubExp, VName, PrimType)
    )
matchesBlkRegTile seg_space kstms
  | -- Variance table, seeded with an empty entry for every name bound
    -- by the segment space and then extended over the kernel statements.
    initial_variance <- M.map mempty $ scopeOfSegSpace seg_space,
    variance <- varianceInStms initial_variance kstms,
    -- The body must split into: leading code, one statement, trailing code.
    (code1, Just screma_stmt, code2) <- matchCodeStreamCode kstms,
    Let pat_redomap _ (Op _) <- screma_stmt,
    -- That statement must be recognised as a tileable redomap.
    Just (common_dim, arrs, (_, red_lam, red_nes, map_lam)) <-
      isTileableRedomap screma_stmt,
    -- Exactly two input arrays and exactly one neutral element.
    [arr_1, arr_2] <- arrs,
    [red_ne] <- red_nes,
    -- The map lambda takes two parameters, the reduce lambda two as
    -- well, and the types involved must all be primitive.
    [map_t1t, map_t2t] <- map paramDec $ lambdaParams map_lam,
    [red_t1, _] <- map paramDec $ lambdaParams red_lam,
    primType map_t1t && primType map_t2t && primType red_t1,
    Just var_dims <- isInvarTo1of2InnerDims mempty seg_space variance arrs,
    -- A single redomap result whose variance is known.
    [redomap_orig_res] <- patNames pat_redomap,
    Just res_red_var <- M.lookup redomap_orig_res variance,
    -- Walk the leading code with 'processIndirections'; it yields the
    -- statements to keep plus a table from array name to the statement
    -- recorded for it.  A strict left fold avoids building a thunk chain
    -- across the whole statement sequence.
    Just (code2'', tab_inv_stm) <-
      L.foldl'
        (processIndirections (namesFromList arrs) res_red_var)
        (Just (Seq.empty, M.empty))
        code1,
    -- Both input arrays must have an entry in that table.  Matching on a
    -- two-element list replaces the former length comparison and makes
    -- the destructuring below total.
    [stm_1, stm_2] <- mapMaybe (`M.lookup` tab_inv_stm) arrs =
      let operand_1 = (stm_1, arr_1, elemType map_t1t)
          operand_2 = (stm_2, arr_2, elemType map_t2t)
          -- The operands are taken in redomap order when var_dims is
          -- exactly [0, 1], and swapped otherwise.
          ((load_A, inp_A, map_t1), (load_B, inp_B, map_t2)) =
            if var_dims == [0, 1]
              then (operand_1, operand_2)
              else (operand_2, operand_1)
          code2' = code2'' <> code2
       in Just
            ( code2',
              (load_A, inp_A, map_t1, load_B, inp_B, map_t2),
              common_dim,
              var_dims,
              (map_lam, red_lam, red_ne, redomap_orig_res, elemType red_t1)
            )
matchesBlkRegTile _ _ = Nothing
-- | Build (without binding it) the expression applying the
-- @SDivUp Int64 Unsafe@ binary operator to @x@ and @y@.
ceilDiv :: (MonadBuilder m) => SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv x y = pure (BasicOp (BinOp roundUpDiv x y))
  where
    roundUpDiv = SDivUp Int64 Unsafe
-- | Emit the statements that compute the tile sizes and the sizes of the
-- two scratch buffers used for tiling.
--
-- The result tuple is
-- @(rx, ry, tx, ty, tk, tk_div_tx, tk_div_ty, tx_rx, ty_ry, a_loc_sz, b_loc_sz)@.
--
-- Note that @tx@ and @rx@ are not computed independently: they are
-- simply aliases of @ty@ and @ry@, which is also why the second
-- argument (the width of B) is unused.  The Boolean argument selects
-- whether @b_loc_sz@ receives an extra @ty * ry@ term.
mkTileMemSizes ::
  SubExp ->
  SubExp ->
  SubExp ->
  Bool ->
  Builder
    GPU
    ( SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp
    )
mkTileMemSizes height_A _width_B common_dim is_B_not_transp = do
  -- Fresh names for the three tile sizes; generation order is Tk, Ty, Ry.
  tkName <- freshTileName "Tk"
  tyName <- freshTileName "Ty"
  ryName <- freshTileName "Ry"

  (ty, ry) <- getParTiles ("Ty", "Ry") (tyName, ryName) height_A
  -- The x-direction tiles reuse the y-direction ones.
  let tx = ty
      rx = ry
  tk <- getSeqTile "Tk" tkName common_dim tx ty

  tk_div_tx <- letSubExp "tk_div_tx" =<< ceilDiv tk tx
  tk_div_ty <- letSubExp "tk_div_ty" =<< ceilDiv tk ty

  tx_rx <- bindSize "TxRx" (pe64 tx * pe64 rx)
  ty_ry <- bindSize "TyRy" (pe64 ty * pe64 ry)

  let pad_term
        | is_B_not_transp = pe64 ty * pe64 ry
        | otherwise = pe64 se0
  a_loc_sz <- bindSize "a_loc_sz" (pe64 ty * pe64 ry * pe64 tk)
  b_loc_sz <- bindSize "b_loc_sz" (pe64 tx * pe64 rx * pe64 tk + pad_term)

  pure
    ( rx,
      ry,
      tx,
      ty,
      tk,
      tk_div_tx,
      tk_div_ty,
      tx_rx,
      ty_ry,
      a_loc_sz,
      b_loc_sz
    )
  where
    -- A fresh 'Name' obtained by pretty-printing a fresh 'VName'.
    freshTileName :: String -> Builder GPU Name
    freshTileName base = nameFromString . prettyString <$> newVName base

    -- Bind a 64-bit size expression to a variable with the given name hint.
    bindSize :: String -> TPrimExp Int64 VName -> Builder GPU SubExp
    bindSize hint e = letSubExp hint =<< toExp e
-- | Bind the grid size (from the given expression) and the thread-block
-- size (@ty * tx@), and return them together with the
-- @SegThreadInBlock@ level (no virtualisation, no sequential
-- dimensions) used for the per-thread segmented operations.
mkNewSegthdLvl ::
  SubExp ->
  SubExp ->
  TPrimExp Int64 VName ->
  Builder GPU (SubExp, SubExp, SegLevel)
mkNewSegthdLvl tx ty grid_pexp = do
  grid_size <- toExp grid_pexp >>= letSubExp "grid_size"
  tblock_size <- toExp (pe64 ty * pe64 tx) >>= letSubExp "tblock_size"
  pure (grid_size, tblock_size, threadLevel)
  where
    threadLevel = SegThreadInBlock $ SegNoVirtFull $ SegSeqDims []
-- | Create three fresh names, @gid_x@, @gid_y@ and @gid_flat@.
--
-- The names are generated in the order @gid_y@, @gid_x@, @gid_flat@ but
-- returned as @(gid_x, gid_y, gid_flat)@.
mkGidsXYF :: Builder GPU (VName, VName, VName)
mkGidsXYF =
  arrange <$> newVName "gid_y" <*> newVName "gid_x" <*> newVName "gid_flat"
  where
    -- Swap the first two components so that x comes first in the result.
    arrange gid_y gid_x gid_flat = (gid_x, gid_y, gid_flat)
-- | Build the initial accumulator array and the two scratch buffers.
--
-- * A 2D segmented map over @(ty, tx)@ at level @segthd_lvl@ (result manifest
--   'ResultPrivate') in which each iteration creates an @ry x rx@ scratch
--   array of type @red_t@ and writes @red_ne@ into every element with two
--   nested loops.
-- * A scratch array @"A_loc"@ of @a_loc_sz@ elements of type @map_t1@.
-- * A scratch array @"B_loc"@ of @b_loc_sz@ elements of type @map_t2@.
--
-- Returns @(cssss, a_loc_init, b_loc_init)@.
initRegShmem ::
  (SubExp, SubExp, SubExp, SubExp, SubExp, SubExp) ->
  (PrimType, PrimType, PrimType) ->
  SegLevel ->
  SubExp ->
  Builder GPU (VName, VName, VName)
initRegShmem
  (rx, tx, ry, ty, a_loc_sz, b_loc_sz)
  (map_t1, map_t2, red_t)
  segthd_lvl
  red_ne = do
    accs <- segMap2D "cssss" segthd_lvl ResultPrivate (ty, tx) (\_ -> neutralTile)
    -- The map body yields exactly one result, hence exactly one name here.
    let [cssss] = accs
    a_loc_init <- scratch "A_loc" map_t1 [a_loc_sz]
    b_loc_init <- scratch "B_loc" map_t2 [b_loc_sz]
    pure (cssss, a_loc_init, b_loc_init)
  where
    -- Body of the segmented map: an ry x rx array filled with red_ne.
    neutralTile = do
      css_init <- scratch "css_init" red_t [ry, rx]
      css <-
        forLoop ry [css_init] $ \i [rows_acc] ->
          yield
            =<< forLoop
              rx
              [rows_acc]
              (\j [elems_acc] -> yield =<< update "css" elems_acc [i, j] red_ne)
      pure [varRes css]
    -- Loop body that simply returns the given array.
    yield arr = resultBodyM [Var arr]
-- | Fetch the element of @redomap_res@ at @[ltid_y, ltid_x]@ and compute the
-- two offsets @ii = iii + ltid_y * ry@ and @jj = jjj + ltid_x * rx@.
--
-- Returns @(css, ii, jj)@, where @css@ is the indexed element (bound as
-- @"redomap_thd"@).
getThdRedomapRes ::
  (SubExp, SubExp) ->
  (VName, VName) ->
  (VName, VName, VName) ->
  Builder GPU (VName, VName, VName)
getThdRedomapRes (rx, ry) (ltid_x, ltid_y) (iii, jjj, redomap_res) = do
  css <- index "redomap_thd" redomap_res [ltid_y, ltid_x]
  ii <- offsetFrom "ii" iii ltid_y ry
  jj <- offsetFrom "jj" jjj ltid_x rx
  pure (css, ii, jj)
  where
    -- Bind @base + ltid * reg_tile@ under the given description.
    offsetFrom :: String -> VName -> VName -> SubExp -> Builder GPU VName
    offsetFrom desc base ltid reg_tile =
      toExp (le64 base + le64 ltid * pe64 reg_tile) >>= letExp desc
-- | Emit three bindings, in this order:
--
-- 1. @redomap_orig_res@ is bound to the element @css[i, j]@,
-- 2. @gtid_y@ is bound to @ii + i@,
-- 3. @gtid_x@ is bound to @jj + j@.
--
-- The names @redomap_orig_res@, @gtid_y@ and @gtid_x@ are supplied by the
-- caller rather than generated freshly.
prereqAddCode2 ::
  (VName, VName) ->
  (VName, VName, VName, VName) ->
  (VName, VName) ->
  Builder GPU ()
prereqAddCode2 (gtid_x, gtid_y) (ii, i, jj, j) (css, redomap_orig_res) = do
  c <- index "redomap_elm" css [i, j]
  addStm =<< mkLetNamesM [redomap_orig_res] (BasicOp (SubExp (Var c)))
  bindSum gtid_y ii i
  bindSum gtid_x jj j
  where
    -- Bind the existing name @dest@ to @base + off@.
    bindSum :: VName -> VName -> VName -> Builder GPU ()
    bindSum dest base off = toExp (le64 base + le64 off) >>= letBindNames [dest]
-- | Split a statement sequence around its first Screma statement.
--
-- Returns @(code1, screma, code2)@ where @code1@ holds the statements before
-- the first statement of the form @Op (OtherOp Screma {})@, @screma@ is that
-- statement (or 'Nothing' if there is none), and @code2@ holds everything
-- after it, including any later Screma statements.  When no Screma is
-- present, @code1@ is the whole input and @code2@ is empty.
--
-- This is a single linear pass; the statements are not accumulated with
-- repeated list appends.
matchCodeStreamCode ::
  Stms GPU ->
  (Stms GPU, Maybe (Stm GPU), Stms GPU)
matchCodeStreamCode kstms =
  (stmsFromList code1, screma, stmsFromList code2)
  where
    (code1, rest) = break isScrema (stmsToList kstms)
    -- Only the first match is singled out; the remainder is kept verbatim.
    (screma, code2) = case rest of
      [] -> (Nothing, [])
      stm : stms -> (Just stm, stms)

    isScrema :: Stm GPU -> Bool
    isScrema (Let _ _ (Op (OtherOp Screma {}))) = True
    isScrema _ = False
-- | Classify every array in @arrs@ by which one of the last two dimensions of
-- @kspace@ it is variant to.
--
-- An array is given index @0@ if, according to @variance@, it varies with the
-- second-to-last dimension only, and @1@ if it varies with the last dimension
-- only.  The result is 'Nothing' if
--
-- * @kspace@ has fewer than two dimensions,
-- * either of those two dimension names occurs in @branch_variant@,
-- * some array varies with both or with neither dimension, or
-- * the resulting list does not contain both @0@ and @1@.
--
-- Otherwise the per-array indices are returned in the order of @arrs@.
isInvarTo1of2InnerDims ::
  Names ->
  SegSpace ->
  VarianceTable ->
  [VName] ->
  Maybe [Int]
isInvarTo1of2InnerDims branch_variant kspace variance arrs = do
  -- Fails as soon as one array cannot be classified.
  inner_perm <- mapM varToOnly1of2InnerDims arrs
  if all (`elem` inner_perm) [0, 1]
    then Just inner_perm
    else Nothing
  where
    varToOnly1of2InnerDims :: VName -> Maybe Int
    varToOnly1of2InnerDims arr =
      case reverse (unSegSpace kspace) of
        (j, _) : (i, _) : _
          | any (`nameIn` branch_variant) [j, i] -> Nothing
          | otherwise ->
              -- Exactly one of the two dimensions must be in the variance set.
              case [ix | (ix, dim) <- zip [0 ..] [i, j], dim `nameIn` variant_to] of
                [ix] -> Just ix
                _ -> Nothing
        _ -> Nothing
      where
        variant_to = M.findWithDefault mempty arr variance
-- | One folding step over a statement, threading an optional accumulator of
-- (kept statements, table of index statements).
--
-- * If the accumulator is already 'Nothing', the result is 'Nothing'.
-- * An 'Index' statement binding exactly one name that occurs in @arrs@ is
--   stored in the table under that name and is /not/ appended to the kept
--   statements.
-- * Any other statement (including an 'Index' not covered above) is appended
--   to the kept statements, provided none of the names it binds occurs in
--   @res_red_var@; if one does, the result is 'Nothing'.
processIndirections ::
  Names ->
  Names ->
  Maybe (Stms GPU, M.Map VName (Stm GPU)) ->
  Stm GPU ->
  Maybe (Stms GPU, M.Map VName (Stm GPU))
processIndirections arrs res_red_var acc stm = do
  (ss, tab) <- acc
  case stm of
    Let patt _ (BasicOp (Index _ _))
      | [p] <- patElems patt,
        patElemName p `nameIn` arrs ->
          Just (ss, M.insert (patElemName p) stm tab)
    -- Reached for non-Index statements and for Index statements that did not
    -- satisfy the guard above.
    Let patt _ _
      | all ((`notNameIn` res_red_var) . patElemName) (patElems patt) ->
          Just (ss Seq.|> stm, tab)
      | otherwise -> Nothing
-- | Choose a (tile, register tile) pair for a dimension of length @len_dim@.
--
-- When @len_dim@ is the 64-bit constant 8, 16 or 32, the pair is the constant
-- @(8, 1)@, @(8, 2)@ or @(8, 4)@ respectively and no statements are emitted.
-- Otherwise two size queries are bound: @GetSize t_name SizeTile@ under the
-- description @t_str@ and @GetSize r_name SizeRegTile@ under @r_str@.
getParTiles :: (String, String) -> (Name, Name) -> SubExp -> Builder GPU (SubExp, SubExp)
getParTiles (t_str, r_str) (t_name, r_name) len_dim =
  case len_dim of
    Constant (IntValue (Int64Value len))
      | Just reg_tile <- lookup len fixedRegTiles -> pure (se8, reg_tile)
    _ ->
      -- The tile query is bound before the register-tile query.
      (,)
        <$> letSubExp t_str (Op (SizeOp (GetSize t_name SizeTile)))
        <*> letSubExp r_str (Op (SizeOp (GetSize r_name SizeRegTile)))
  where
    -- Register tile used for each statically known dimension length.
    fixedRegTiles :: [(Int64, SubExp)]
    fixedRegTiles = [(8, se1), (16, se2), (32, se4)]
-- | Bind, under the description @tk_str@, a tile size derived from @tx@ and
-- @ty@.
--
-- If both @tx@ and @ty@ are 64-bit constants, the bound value is the constant
-- @min tx ty@, further limited to @len_dim@ when that is a 64-bit constant as
-- well.  Otherwise the bound expression is the size query
-- @GetSize tk_name SizeTile@.
getSeqTile :: String -> Name -> SubExp -> SubExp -> SubExp -> Builder GPU SubExp
getSeqTile tk_str tk_name len_dim tx ty =
  letSubExp tk_str $ case (tx, ty) of
    (Constant (IntValue (Int64Value v_x)), Constant (IntValue (Int64Value v_y))) ->
      BasicOp (SubExp (constant (capByDim (min v_x v_y))))
    _ ->
      Op (SizeOp (GetSize tk_name SizeTile))
  where
    -- Limit a tile size to the dimension length when the latter is known.
    capByDim :: Int64 -> Int64
    capByDim v = case len_dim of
      Constant (IntValue (Int64Value v_d)) -> min v_d v
      _ -> v
-- | The constant 30.
--
-- NOTE(review): presumably an upper bound on register-tile sizes; the use
-- sites are outside this section, so confirm before relying on that reading.
maxRegTile :: Int64
maxRegTile = 30
-- | Wrap a 64-bit integer as a constant 'SubExp'.
mkRegTileSe :: Int64 -> SubExp
mkRegTileSe v = constant v
-- | Does @nm@ depend on @gid_outer@?
--
-- True when @nm@ is @gid_outer@ itself, or when @gid_outer@ occurs in the
-- entry for @nm@ in the variance table.  A name without an entry is treated
-- as having an empty variance set.
variantToDim :: VarianceTable -> VName -> VName -> Bool
variantToDim variance gid_outer nm
  | gid_outer == nm = True
  | otherwise = gid_outer `nameIn` M.findWithDefault mempty nm variance
-- | Classify every array in @arrs@ by which one of the last three dimensions
-- of @kspace@ it is variant to.
--
-- An array is given index @0@ if, according to @variance@, it varies only
-- with the third-to-last dimension, @1@ if only with the second-to-last, and
-- @2@ if only with the last.  The result is 'Nothing' if
--
-- * @kspace@ has fewer than three dimensions,
-- * any of those three dimension names occurs in @branch_variant@,
-- * some array varies with none or with more than one of the dimensions, or
-- * the resulting list does not contain all of @0@, @1@ and @2@.
--
-- Otherwise the per-array indices are returned in the order of @arrs@.
isInvarTo2of3InnerDims ::
  Names ->
  SegSpace ->
  VarianceTable ->
  [VName] ->
  Maybe [Int]
isInvarTo2of3InnerDims branch_variant kspace variance arrs = do
  -- Fails as soon as one array cannot be classified.
  inner_perm <- mapM varToOnly1of3InnerDims arrs
  if all (`elem` inner_perm) [0, 1, 2]
    then Just inner_perm
    else Nothing
  where
    varToOnly1of3InnerDims :: VName -> Maybe Int
    varToOnly1of3InnerDims arr =
      case reverse (unSegSpace kspace) of
        (k, _) : (j, _) : (i, _) : _
          | any (`nameIn` branch_variant) [k, j, i] -> Nothing
          | otherwise ->
              -- Exactly one of the three dimensions must be in the variance set.
              case [ix | (ix, dim) <- zip [0 ..] [i, j, k], dim `nameIn` variant_to] of
                [ix] -> Just ix
                _ -> Nothing
        _ -> Nothing
      where
        variant_to = M.findWithDefault mempty arr variance
doRegTiling3D :: Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
doRegTiling3D :: Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
doRegTiling3D (Let Pat (LetDec GPU)
pat StmAux (ExpDec GPU)
aux (Op (SegOp SegOp SegLevel GPU
old_kernel)))
| SegMap SegThread {} SegSpace
space [TypeBase Shape NoUniqueness]
kertp (KernelBody () Stms GPU
kstms [KernelResult]
kres) <- SegOp SegLevel GPU
old_kernel,
Map VName Names
initial_variance <- (NameInfo Any -> Names)
-> Map VName (NameInfo Any) -> Map VName Names
forall a b k. (a -> b) -> Map k a -> Map k b
M.map NameInfo Any -> Names
forall a. Monoid a => a
mempty (Map VName (NameInfo Any) -> Map VName Names)
-> Map VName (NameInfo Any) -> Map VName Names
forall a b. (a -> b) -> a -> b
$ SegSpace -> Map VName (NameInfo Any)
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space,
Map VName Names
variance <- Map VName Names -> Stms GPU -> Map VName Names
varianceInStms Map VName Names
initial_variance Stms GPU
kstms,
(VName
gtid_x, SubExp
d_Kx) : (VName
gtid_y, SubExp
d_Ky) : (VName
gtid_z, SubExp
d_M) : [(VName, SubExp)]
rem_outer_dims_rev <- [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space,
[(VName, SubExp)]
rem_outer_dims <- [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse [(VName, SubExp)]
rem_outer_dims_rev,
(Stms GPU
code1, Just Stm GPU
screma_stmt, Stms GPU
code2) <- Stms GPU -> (Stms GPU, Maybe (Stm GPU), Stms GPU)
matchCodeStreamCode Stms GPU
kstms,
Let Pat (LetDec GPU)
pat_redomap StmAux (ExpDec GPU)
_ (Op Op GPU
_) <- Stm GPU
screma_stmt,
Just (SubExp
common_dim, [VName]
inp_soac_arrs, (Commutativity
_, Lambda GPU
red_lam, [SubExp]
red_nes, Lambda GPU
map_lam)) <- Stm GPU
-> Maybe
(SubExp, [VName],
(Commutativity, Lambda GPU, [SubExp], Lambda GPU))
isTileableRedomap Stm GPU
screma_stmt,
Bool -> Bool
not ([SubExp] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [SubExp]
red_nes),
Int
num_res <- Int -> Int -> Int
forall a. Ord a => a -> a -> a
max ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) ([KernelResult] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
kres),
Int64
reg_tile <- Int64
maxRegTile Int64 -> Int64 -> Int64
forall a. Integral a => a -> a -> a
`quot` Int -> Int64
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
num_res,
SubExp
reg_tile_se <- Int64 -> SubExp
mkRegTileSe Int64
reg_tile,
(Param (TypeBase Shape NoUniqueness) -> Bool)
-> [Param (TypeBase Shape NoUniqueness)] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> (Param (TypeBase Shape NoUniqueness)
-> TypeBase Shape NoUniqueness)
-> Param (TypeBase Shape NoUniqueness)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param (TypeBase Shape NoUniqueness) -> TypeBase Shape NoUniqueness
forall dec. Param dec -> dec
paramDec) ([Param (TypeBase Shape NoUniqueness)] -> Bool)
-> [Param (TypeBase Shape NoUniqueness)] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda GPU -> [LParam GPU]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPU
map_lam,
[TypeBase Shape NoUniqueness]
red_res_tps <- (Param (TypeBase Shape NoUniqueness)
-> TypeBase Shape NoUniqueness)
-> [Param (TypeBase Shape NoUniqueness)]
-> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape NoUniqueness) -> TypeBase Shape NoUniqueness
forall dec. Param dec -> dec
paramDec ([Param (TypeBase Shape NoUniqueness)]
-> [TypeBase Shape NoUniqueness])
-> [Param (TypeBase Shape NoUniqueness)]
-> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> a -> b
$ Int
-> [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
forall a. Int -> [a] -> [a]
take ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) ([Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)])
-> [Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ Lambda GPU -> [LParam GPU]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPU
red_lam,
(TypeBase Shape NoUniqueness -> Bool)
-> [TypeBase Shape NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType [TypeBase Shape NoUniqueness]
red_res_tps,
Just [Int]
_ <- Names -> SegSpace -> Map VName Names -> [VName] -> Maybe [Int]
isInvarTo2of3InnerDims Names
forall a. Monoid a => a
mempty SegSpace
space Map VName Names
variance [VName]
inp_soac_arrs,
[PatElem (TypeBase Shape NoUniqueness)]
redomap_orig_res <- Pat (TypeBase Shape NoUniqueness)
-> [PatElem (TypeBase Shape NoUniqueness)]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (TypeBase Shape NoUniqueness)
Pat (LetDec GPU)
pat_redomap,
Names
res_red_var <-
[Names] -> Names
forall a. Monoid a => [a] -> a
mconcat ([Names] -> Names) -> [Names] -> Names
forall a b. (a -> b) -> a -> b
$ (PatElem (TypeBase Shape NoUniqueness) -> Maybe Names)
-> [PatElem (TypeBase Shape NoUniqueness)] -> [Names]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe ((VName -> Map VName Names -> Maybe Names
forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` Map VName Names
variance) (VName -> Maybe Names)
-> (PatElem (TypeBase Shape NoUniqueness) -> VName)
-> PatElem (TypeBase Shape NoUniqueness)
-> Maybe Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElem (TypeBase Shape NoUniqueness) -> VName
forall dec. PatElem dec -> VName
patElemName) [PatElem (TypeBase Shape NoUniqueness)]
redomap_orig_res,
Names
forall a. Monoid a => a
mempty Names -> Names -> Bool
forall a. Eq a => a -> a -> Bool
/= Names
res_red_var,
Just (Stms GPU
code2'', Map VName (Stm GPU)
arr_tab0) <-
(Maybe (Stms GPU, Map VName (Stm GPU))
-> Stm GPU -> Maybe (Stms GPU, Map VName (Stm GPU)))
-> Maybe (Stms GPU, Map VName (Stm GPU))
-> Stms GPU
-> Maybe (Stms GPU, Map VName (Stm GPU))
forall b a. (b -> a -> b) -> b -> Seq a -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
(Names
-> Names
-> Maybe (Stms GPU, Map VName (Stm GPU))
-> Stm GPU
-> Maybe (Stms GPU, Map VName (Stm GPU))
processIndirections ([VName] -> Names
namesFromList [VName]
inp_soac_arrs) Names
res_red_var)
((Stms GPU, Map VName (Stm GPU))
-> Maybe (Stms GPU, Map VName (Stm GPU))
forall a. a -> Maybe a
Just (Stms GPU
forall a. Seq a
Seq.empty, Map VName (Stm GPU)
forall k a. Map k a
M.empty))
Stms GPU
code1,
[Stm GPU]
tmp_stms <- (VName -> Maybe (Stm GPU)) -> [VName] -> [Stm GPU]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (VName -> Map VName (Stm GPU) -> Maybe (Stm GPU)
forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` Map VName (Stm GPU)
arr_tab0) [VName]
inp_soac_arrs,
[Stm GPU] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Stm GPU]
tmp_stms Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
inp_soac_arrs,
Stms GPU
code2' <- Stms GPU
code2'' Stms GPU -> Stms GPU -> Stms GPU
forall a. Semigroup a => a -> a -> a
<> Stms GPU
code2,
[VName]
ker_res_nms <- (KernelResult -> Maybe VName) -> [KernelResult] -> [VName]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe KernelResult -> Maybe VName
getResNm [KernelResult]
kres,
[VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
ker_res_nms Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [KernelResult] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
kres,
(TypeBase Shape NoUniqueness -> Bool)
-> [TypeBase Shape NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType [TypeBase Shape NoUniqueness]
kertp,
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Map VName Names -> VName -> VName -> Bool
variantToDim Map VName Names
variance VName
gtid_z) [VName]
ker_res_nms = do
(Stm GPU
new_kernel, Stms GPU
host_stms) <- Builder GPU (Stm GPU)
-> ReaderT (Scope GPU) (State VNameSource) (Stm GPU, Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder (Builder GPU (Stm GPU)
-> ReaderT (Scope GPU) (State VNameSource) (Stm GPU, Stms GPU))
-> Builder GPU (Stm GPU)
-> ReaderT (Scope GPU) (State VNameSource) (Stm GPU, Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
(Map VName (Stm GPU)
tab_inn, Map VName (PrimType, Stm GPU)
tab_out) <-
((Map VName (Stm GPU), Map VName (PrimType, Stm GPU))
-> (VName, Stm GPU)
-> BuilderT
GPU
(State VNameSource)
(Map VName (Stm GPU), Map VName (PrimType, Stm GPU)))
-> (Map VName (Stm GPU), Map VName (PrimType, Stm GPU))
-> [(VName, Stm GPU)]
-> BuilderT
GPU
(State VNameSource)
(Map VName (Stm GPU), Map VName (PrimType, Stm GPU))
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM
(Map VName Names
-> (VName, SubExp)
-> (Map VName (Stm GPU), Map VName (PrimType, Stm GPU))
-> (VName, Stm GPU)
-> BuilderT
GPU
(State VNameSource)
(Map VName (Stm GPU), Map VName (PrimType, Stm GPU))
insertTranspose Map VName Names
variance (VName
gtid_z, SubExp
d_M))
(Map VName (Stm GPU)
forall k a. Map k a
M.empty, Map VName (PrimType, Stm GPU)
forall k a. Map k a
M.empty)
([(VName, Stm GPU)]
-> BuilderT
GPU
(State VNameSource)
(Map VName (Stm GPU), Map VName (PrimType, Stm GPU)))
-> [(VName, Stm GPU)]
-> BuilderT
GPU
(State VNameSource)
(Map VName (Stm GPU), Map VName (PrimType, Stm GPU))
forall a b. (a -> b) -> a -> b
$ Map VName (Stm GPU) -> [(VName, Stm GPU)]
forall k a. Map k a -> [(k, a)]
M.toList Map VName (Stm GPU)
arr_tab0
Name
tx_name <- [Char] -> Name
nameFromString ([Char] -> Name) -> (VName -> [Char]) -> VName -> Name
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString (VName -> Name)
-> BuilderT GPU (State VNameSource) VName
-> BuilderT GPU (State VNameSource) Name
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Char] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"Tx"
Name
ty_name <- [Char] -> Name
nameFromString ([Char] -> Name) -> (VName -> [Char]) -> VName -> Name
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString (VName -> Name)
-> BuilderT GPU (State VNameSource) VName
-> BuilderT GPU (State VNameSource) Name
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Char] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"Ty"
SubExp
tx0 <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"Tx" (Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp)
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ Op (Rep (BuilderT GPU (State VNameSource)))
-> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. Op rep -> Exp rep
Op (Op (Rep (BuilderT GPU (State VNameSource)))
-> Exp (Rep (BuilderT GPU (State VNameSource))))
-> Op (Rep (BuilderT GPU (State VNameSource)))
-> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ SizeOp -> HostOp SOAC GPU
forall (op :: * -> *) rep. SizeOp -> HostOp op rep
SizeOp (SizeOp -> HostOp SOAC GPU) -> SizeOp -> HostOp SOAC GPU
forall a b. (a -> b) -> a -> b
$ Name -> SizeClass -> SizeOp
GetSize Name
tx_name SizeClass
SizeTile
SubExp
ty0 <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"Ty" (Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp)
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ Op (Rep (BuilderT GPU (State VNameSource)))
-> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. Op rep -> Exp rep
Op (Op (Rep (BuilderT GPU (State VNameSource)))
-> Exp (Rep (BuilderT GPU (State VNameSource))))
-> Op (Rep (BuilderT GPU (State VNameSource)))
-> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ SizeOp -> HostOp SOAC GPU
forall (op :: * -> *) rep. SizeOp -> HostOp op rep
SizeOp (SizeOp -> HostOp SOAC GPU) -> SizeOp -> HostOp SOAC GPU
forall a b. (a -> b) -> a -> b
$ Name -> SizeClass -> SizeOp
GetSize Name
ty_name SizeClass
SizeTile
SubExp
ty <- [Char]
-> SubExp -> SubExp -> BuilderT GPU (State VNameSource) SubExp
limitTile [Char]
"Ty" SubExp
ty0 SubExp
d_Ky
SubExp
tx <- [Char]
-> SubExp -> SubExp -> BuilderT GPU (State VNameSource) SubExp
limitTile [Char]
"Tx" SubExp
tx0 SubExp
d_Kx
let rz :: SubExp
rz = SubExp
reg_tile_se
SubExp
gridDim_x <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"gridDim_x" (Exp GPU -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv SubExp
d_Kx SubExp
tx
SubExp
gridDim_y <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"gridDim_y" (Exp GPU -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv SubExp
d_Ky SubExp
ty
SubExp
gridDim_z <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"gridDim_z" (Exp GPU -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv SubExp
d_M SubExp
rz
let gridxyz_pexp :: TPrimExp Int64 VName
gridxyz_pexp = SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_z TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
gridDim_x
let grid_pexp :: TPrimExp Int64 VName
grid_pexp = [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
gridxyz_pexp TPrimExp Int64 VName
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. a -> [a] -> [a]
: ((VName, SubExp) -> TPrimExp Int64 VName)
-> [(VName, SubExp)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName)
-> ((VName, SubExp) -> SubExp)
-> (VName, SubExp)
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) [(VName, SubExp)]
rem_outer_dims_rev
SubExp
grid_size <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"grid_size_tile3d" (Exp GPU -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp TPrimExp Int64 VName
grid_pexp
SubExp
tblock_size <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"tblock_size_tile3d" (Exp GPU -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
ty TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx)
let segthd_lvl :: SegLevel
segthd_lvl = SegVirt -> SegLevel
SegThreadInBlock (SegSeqDims -> SegVirt
SegNoVirtFull ([Int] -> SegSeqDims
SegSeqDims []))
SubExp
count_shmem <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"count_shmem" (Exp GPU -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp
-> SubExp
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv SubExp
rz SubExp
tblock_size
VName
gid_x <- [Char] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_x"
VName
gid_y <- [Char] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_y"
VName
gid_z <- [Char] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_z"
VName
gid_flat <- [Char] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_flat"
([KernelResult]
ret_seggroup, Stms GPU
stms_seggroup) <- Builder GPU [KernelResult]
-> BuilderT GPU (State VNameSource) ([KernelResult], Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder (Builder GPU [KernelResult]
-> BuilderT GPU (State VNameSource) ([KernelResult], Stms GPU))
-> Builder GPU [KernelResult]
-> BuilderT GPU (State VNameSource) ([KernelResult], Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
VName
ii <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"ii" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_z TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
rz)
VName
jj1 <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"jj1" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_y TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
ty)
VName
jj2 <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"jj2" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gid_x TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tx)
[VName]
reg_arr_nms <- [Char]
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName]
segMap2D [Char]
"res" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
ty, SubExp
tx) (((VName, VName) -> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName])
-> ((VName, VName) -> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \(VName, VName)
_ ->
[(SubExp, TypeBase Shape NoUniqueness)]
-> ((SubExp, TypeBase Shape NoUniqueness)
-> BuilderT GPU (State VNameSource) SubExpRes)
-> BuilderT GPU (State VNameSource) [SubExpRes]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SubExp]
-> [TypeBase Shape NoUniqueness]
-> [(SubExp, TypeBase Shape NoUniqueness)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
red_nes [TypeBase Shape NoUniqueness]
red_res_tps) (((SubExp, TypeBase Shape NoUniqueness)
-> BuilderT GPU (State VNameSource) SubExpRes)
-> BuilderT GPU (State VNameSource) [SubExpRes])
-> ((SubExp, TypeBase Shape NoUniqueness)
-> BuilderT GPU (State VNameSource) SubExpRes)
-> BuilderT GPU (State VNameSource) [SubExpRes]
forall a b. (a -> b) -> a -> b
$ \(SubExp
red_ne, TypeBase Shape NoUniqueness
red_t) -> do
VName
css_init <- [Char]
-> PrimType -> [SubExp] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> PrimType -> [SubExp] -> m VName
scratch [Char]
"res_init" (TypeBase Shape NoUniqueness -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType TypeBase Shape NoUniqueness
red_t) [SubExp
rz]
VName
css <- SubExp
-> [VName]
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forLoop SubExp
rz [VName
css_init] ((VName -> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName)
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
i [VName
css_merge] -> do
VName
css' <- [Char]
-> VName
-> [VName]
-> SubExp
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> SubExp -> m VName
update [Char]
"css" VName
css_merge [VName
i] SubExp
red_ne
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [VName -> SubExp
Var VName
css']
SubExpRes -> BuilderT GPU (State VNameSource) SubExpRes
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExpRes -> BuilderT GPU (State VNameSource) SubExpRes)
-> SubExpRes -> BuilderT GPU (State VNameSource) SubExpRes
forall a b. (a -> b) -> a -> b
$ VName -> SubExpRes
varRes VName
css
[VName]
loc_arr_nms <- [(VName, (PrimType, Stm GPU))]
-> ((VName, (PrimType, Stm GPU))
-> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (Map VName (PrimType, Stm GPU) -> [(VName, (PrimType, Stm GPU))]
forall k a. Map k a -> [(k, a)]
M.toList Map VName (PrimType, Stm GPU)
tab_out) (((VName, (PrimType, Stm GPU))
-> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName])
-> ((VName, (PrimType, Stm GPU))
-> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
nm, (PrimType
ptp, Stm GPU
_)) ->
[Char]
-> PrimType -> [SubExp] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> PrimType -> [SubExp] -> m VName
scratch (VName -> [Char]
baseString VName
nm [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
"_loc") PrimType
ptp [SubExp
rz]
[VName]
prologue_res_list <-
SubExp
-> [VName]
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName]
forLoop' SubExp
common_dim ([VName]
reg_arr_nms [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
loc_arr_nms) ((VName -> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName])
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$
\VName
q [VName]
var_nms -> do
let reg_arr_merge_nms :: [VName]
reg_arr_merge_nms = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
take ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) [VName]
var_nms
let loc_arr_merge_nms :: [VName]
loc_arr_merge_nms = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) [VName]
var_nms
[VName]
loc_arr_nms' <-
SubExp
-> [VName]
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName]
forLoop' SubExp
count_shmem [VName]
loc_arr_merge_nms ((VName -> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName])
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \VName
tt [VName]
loc_arr_merge2_nms -> do
[VName]
loc_arr_merge2_nms' <-
[(VName, (VName, (PrimType, Stm GPU)))]
-> ((VName, (VName, (PrimType, Stm GPU)))
-> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName]
-> [(VName, (PrimType, Stm GPU))]
-> [(VName, (VName, (PrimType, Stm GPU)))]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
loc_arr_merge2_nms (Map VName (PrimType, Stm GPU) -> [(VName, (PrimType, Stm GPU))]
forall k a. Map k a -> [(k, a)]
M.toList Map VName (PrimType, Stm GPU)
tab_out)) (((VName, (VName, (PrimType, Stm GPU)))
-> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName])
-> ((VName, (VName, (PrimType, Stm GPU)))
-> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
loc_Y_nm, (VName
glb_Y_nm, (PrimType
ptp_Y, Stm GPU
load_Y))) -> do
VName
ltid_flat <- [Char] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"ltid_flat"
VName
ltid <- [Char] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"ltid"
let segspace :: SegSpace
segspace = VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
ltid_flat [(VName
ltid, SubExp
tblock_size)]
((SubExp
res_v, SubExp
res_i), Stms GPU
stms) <- Builder GPU (SubExp, SubExp)
-> BuilderT GPU (State VNameSource) ((SubExp, SubExp), Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder (Builder GPU (SubExp, SubExp)
-> BuilderT GPU (State VNameSource) ((SubExp, SubExp), Stms GPU))
-> Builder GPU (SubExp, SubExp)
-> BuilderT GPU (State VNameSource) ((SubExp, SubExp), Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
VName
offs <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"offs" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
tblock_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
tt)
VName
loc_ind <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"loc_ind" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
offs)
[VName]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
gtid_z] (Exp GPU -> BuilderT GPU (State VNameSource) ())
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ii TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
loc_ind)
let glb_ind :: VName
glb_ind = VName
gtid_z
SubExp
y_elm <-
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"y_elem"
(Exp GPU -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource)))))
-> TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
glb_ind TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_M)
( do
Stm (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm Stm (Rep (BuilderT GPU (State VNameSource)))
Stm GPU
load_Y
VName
res <- [Char]
-> VName -> [VName] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> m VName
index [Char]
"Y_elem" VName
glb_Y_nm [VName
q]
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [VName -> SubExp
Var VName
res]
)
([BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [TypeBase Shape NoUniqueness
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
TypeBase Shape NoUniqueness -> m (Exp (Rep m))
eBlank (TypeBase Shape NoUniqueness
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource)))))
-> TypeBase Shape NoUniqueness
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase Shape NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
ptp_Y])
SubExp
y_ind <-
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"y_loc_ind"
(Exp GPU -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource)))))
-> TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
loc_ind TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
rz)
(VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *). MonadBuilder m => VName -> m (Exp (Rep m))
toExp VName
loc_ind BuilderT GPU (State VNameSource) (Exp GPU)
-> (Exp GPU -> BuilderT GPU (State VNameSource) [SubExp])
-> BuilderT GPU (State VNameSource) [SubExp]
forall a b.
BuilderT GPU (State VNameSource) a
-> (a -> BuilderT GPU (State VNameSource) b)
-> BuilderT GPU (State VNameSource) b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) [SubExp]
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [SubExp]
letTupExp' [Char]
"loc_fi" BuilderT GPU (State VNameSource) [SubExp]
-> ([SubExp] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) (Body GPU)
forall a b.
BuilderT GPU (State VNameSource) a
-> (a -> BuilderT GPU (State VNameSource) b)
-> BuilderT GPU (State VNameSource) b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= [SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
[SubExp] -> BuilderT GPU (State VNameSource) (Body GPU)
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM)
([BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource)))))
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1)])
(SubExp, SubExp) -> Builder GPU (SubExp, SubExp)
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp
y_elm, SubExp
y_ind)
let ret :: KernelResult
ret = Certs -> VName -> [(Slice SubExp, SubExp)] -> KernelResult
WriteReturns Certs
forall a. Monoid a => a
mempty VName
loc_Y_nm [([DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
res_i], SubExp
res_v)]
let body :: KernelBody GPU
body = BodyDec GPU -> Stms GPU -> [KernelResult] -> KernelBody GPU
forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () Stms GPU
stms [KernelResult
ret]
TypeBase Shape NoUniqueness
loc_Y_nm_t <- VName
-> BuilderT GPU (State VNameSource) (TypeBase Shape NoUniqueness)
forall rep (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
loc_Y_nm
[VName]
res_nms <-
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU [VName]
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [VName]
letTupExp [Char]
"Y_glb2loc" (Exp GPU -> Builder GPU [VName])
-> (Exp GPU -> BuilderT GPU (State VNameSource) (Exp GPU))
-> Exp GPU
-> Builder GPU [VName]
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< Exp GPU -> BuilderT GPU (State VNameSource) (Exp GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Exp rep -> m (Exp rep)
renameExp (Exp GPU -> Builder GPU [VName]) -> Exp GPU -> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$
Op GPU -> Exp GPU
HostOp SOAC GPU -> Exp GPU
forall rep. Op rep -> Exp rep
Op (HostOp SOAC GPU -> Exp GPU)
-> (SegOp SegLevel GPU -> HostOp SOAC GPU)
-> SegOp SegLevel GPU
-> Exp GPU
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegOp SegLevel GPU -> HostOp SOAC GPU
forall (op :: * -> *) rep. SegOp SegLevel rep -> HostOp op rep
SegOp (SegOp SegLevel GPU -> Exp GPU) -> SegOp SegLevel GPU -> Exp GPU
forall a b. (a -> b) -> a -> b
$
SegLevel
-> SegSpace
-> [TypeBase Shape NoUniqueness]
-> KernelBody GPU
-> SegOp SegLevel GPU
forall lvl rep.
lvl
-> SegSpace
-> [TypeBase Shape NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegMap SegLevel
segthd_lvl SegSpace
segspace [TypeBase Shape NoUniqueness
loc_Y_nm_t] KernelBody GPU
body
let VName
res_nm : [VName]
_ = [VName]
res_nms
VName -> BuilderT GPU (State VNameSource) VName
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
res_nm
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM ([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource)))))
-> [SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
loc_arr_merge2_nms'
[VName]
redomap_res <-
[Char]
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName]
segMap2D [Char]
"redomap_res" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
ty, SubExp
tx) (((VName, VName) -> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName])
-> ((VName, VName) -> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$
\(VName
ltid_y, VName
ltid_x) -> do
[VName]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
gtid_y] (Exp GPU -> BuilderT GPU (State VNameSource) ())
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
jj1 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_y)
[VName]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
gtid_x] (Exp GPU -> BuilderT GPU (State VNameSource) ())
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
jj2 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_x)
[VName]
reg_arr_merge_nms_slc <- [VName]
-> (VName -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
reg_arr_merge_nms ((VName -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName])
-> (VName -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \VName
reg_arr_nm ->
[Char]
-> VName -> [VName] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> m VName
index [Char]
"res_reg_slc" VName
reg_arr_nm [VName
ltid_y, VName
ltid_x]
([SubExp] -> [SubExpRes])
-> BuilderT GPU (State VNameSource) [SubExp]
-> BuilderT GPU (State VNameSource) [SubExpRes]
forall a b.
(a -> b)
-> BuilderT GPU (State VNameSource) a
-> BuilderT GPU (State VNameSource) b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> [SubExpRes]
subExpsRes (BuilderT GPU (State VNameSource) [SubExp]
-> BuilderT GPU (State VNameSource) [SubExpRes])
-> (Exp GPU -> BuilderT GPU (State VNameSource) [SubExp])
-> Exp GPU
-> BuilderT GPU (State VNameSource) [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) [SubExp]
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [SubExp]
letTupExp' [Char]
"redomap_guarded"
(Exp GPU -> BuilderT GPU (State VNameSource) [SubExpRes])
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource)))))
-> TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_y TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_Ky TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_x TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_Kx)
( do
[VName]
inp_scals_invar_outer <-
[(VName, Stm GPU)]
-> ((VName, Stm GPU) -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (Map VName (Stm GPU) -> [(VName, Stm GPU)]
forall k a. Map k a -> [(k, a)]
M.toList Map VName (Stm GPU)
tab_inn) (((VName, Stm GPU) -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName])
-> ((VName, Stm GPU) -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
inp_arr_nm, Stm GPU
load_stm) -> do
Stm (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm Stm (Rep (BuilderT GPU (State VNameSource)))
Stm GPU
load_stm
[Char]
-> VName -> [VName] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> m VName
index (VName -> [Char]
baseString VName
inp_arr_nm) VName
inp_arr_nm [VName
q]
[VName]
reg_arr_merge_nms' <-
SubExp
-> [VName]
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName]
forLoop' SubExp
rz [VName]
reg_arr_merge_nms_slc ((VName -> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName])
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \VName
i [VName]
reg_arr_mm_nms -> do
[VName]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
gtid_z] (Exp GPU -> BuilderT GPU (State VNameSource) ())
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ii TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i)
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
[SubExp] -> BuilderT GPU (State VNameSource) (Body GPU)
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM
([SubExp] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) [SubExp]
-> BuilderT GPU (State VNameSource) (Body GPU)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) [SubExp]
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [SubExp]
letTupExp' [Char]
"redomap_lam"
(Exp GPU -> BuilderT GPU (State VNameSource) [SubExp])
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) [SubExp]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource)))))
-> TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_z TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_M)
( do
[VName]
ys <- [VName]
-> (VName -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
loc_arr_nms' ((VName -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName])
-> (VName -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \VName
loc_arr_nm ->
[Char]
-> VName -> [VName] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> m VName
index [Char]
"inp_reg_var2z" VName
loc_arr_nm [VName
i]
[VName]
cs <- [VName]
-> (VName -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
reg_arr_mm_nms ((VName -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName])
-> (VName -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \VName
reg_arr_nm ->
[Char]
-> VName -> [VName] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> m VName
index [Char]
"res_reg_var2z" VName
reg_arr_nm [VName
i]
let tab_scals :: Map VName VName
tab_scals =
[(VName, VName)] -> Map VName VName
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, VName)] -> Map VName VName)
-> [(VName, VName)] -> Map VName VName
forall a b. (a -> b) -> a -> b
$
[VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (((VName, (PrimType, Stm GPU)) -> VName)
-> [(VName, (PrimType, Stm GPU))] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, (PrimType, Stm GPU)) -> VName
forall a b. (a, b) -> a
fst ([(VName, (PrimType, Stm GPU))] -> [VName])
-> [(VName, (PrimType, Stm GPU))] -> [VName]
forall a b. (a -> b) -> a -> b
$ Map VName (PrimType, Stm GPU) -> [(VName, (PrimType, Stm GPU))]
forall k a. Map k a -> [(k, a)]
M.toList Map VName (PrimType, Stm GPU)
tab_out) [VName]
ys
[(VName, VName)] -> [(VName, VName)] -> [(VName, VName)]
forall a. [a] -> [a] -> [a]
++ [VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (((VName, Stm GPU) -> VName) -> [(VName, Stm GPU)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, Stm GPU) -> VName
forall a b. (a, b) -> a
fst ([(VName, Stm GPU)] -> [VName]) -> [(VName, Stm GPU)] -> [VName]
forall a b. (a -> b) -> a -> b
$ Map VName (Stm GPU) -> [(VName, Stm GPU)]
forall k a. Map k a -> [(k, a)]
M.toList Map VName (Stm GPU)
tab_inn) [VName]
inp_scals_invar_outer
[VName]
map_inp_scals <- [VName]
-> (VName -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
inp_soac_arrs ((VName -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName])
-> (VName -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \VName
arr_nm ->
case VName -> Map VName VName -> Maybe VName
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
arr_nm Map VName VName
tab_scals of
Maybe VName
Nothing -> [Char] -> BuilderT GPU (State VNameSource) VName
forall a. HasCallStack => [Char] -> a
error [Char]
"Impossible case reached in tiling3D\n"
Just VName
nm -> VName -> BuilderT GPU (State VNameSource) VName
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
nm
Lambda GPU
map_lam' <- Lambda GPU -> BuilderT GPU (State VNameSource) (Lambda GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPU
map_lam
Lambda GPU
red_lam' <- Lambda GPU -> BuilderT GPU (State VNameSource) (Lambda GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPU
red_lam
[SubExpRes]
map_res_scals <- Lambda (Rep (BuilderT GPU (State VNameSource)))
-> [BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))]
-> BuilderT GPU (State VNameSource) [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep (BuilderT GPU (State VNameSource)))
Lambda GPU
map_lam' ((VName -> BuilderT GPU (State VNameSource) (Exp GPU))
-> [VName] -> [BuilderT GPU (State VNameSource) (Exp GPU)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
SubExp -> BuilderT GPU (State VNameSource) (Exp GPU)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> BuilderT GPU (State VNameSource) (Exp GPU))
-> (VName -> SubExp)
-> VName
-> BuilderT GPU (State VNameSource) (Exp GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName]
map_inp_scals)
[SubExpRes]
red_res <- Lambda (Rep (BuilderT GPU (State VNameSource)))
-> [BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))]
-> BuilderT GPU (State VNameSource) [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep (BuilderT GPU (State VNameSource)))
Lambda GPU
red_lam' ((SubExp -> BuilderT GPU (State VNameSource) (Exp GPU))
-> [SubExp] -> [BuilderT GPU (State VNameSource) (Exp GPU)]
forall a b. (a -> b) -> [a] -> [b]
map SubExp
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
SubExp -> BuilderT GPU (State VNameSource) (Exp GPU)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ((VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
cs [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp [SubExpRes]
map_res_scals))
[VName]
css <- [(VName, SubExpRes)]
-> ((VName, SubExpRes) -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName] -> [SubExpRes] -> [(VName, SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
reg_arr_mm_nms [SubExpRes]
red_res) (((VName, SubExpRes) -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName])
-> ((VName, SubExpRes) -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
reg_arr_nm, SubExpRes
c) ->
[Char]
-> VName
-> [VName]
-> SubExp
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> SubExp -> m VName
update (VName -> [Char]
baseString VName
reg_arr_nm) VName
reg_arr_nm [VName
i] (SubExpRes -> SubExp
resSubExp SubExpRes
c)
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM ([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource)))))
-> [SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
css
)
([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM ([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource)))))
-> [SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
reg_arr_mm_nms)
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM ([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource)))))
-> [SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
reg_arr_merge_nms'
)
([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM ([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource)))))
-> [SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
reg_arr_merge_nms_slc)
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM ([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource)))))
-> [SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> [SubExp]) -> [VName] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ [VName]
redomap_res [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
loc_arr_nms'
let redomap_res :: [VName]
redomap_res = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
take ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) [VName]
prologue_res_list
[VName]
epilogue_res <-
if [PatElem (TypeBase Shape NoUniqueness)] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem (TypeBase Shape NoUniqueness)]
redomap_orig_res Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
ker_res_nms
Bool -> Bool -> Bool
&& [VName]
ker_res_nms [VName] -> [VName] -> Bool
forall a. Eq a => a -> a -> Bool
== (PatElem (TypeBase Shape NoUniqueness) -> VName)
-> [PatElem (TypeBase Shape NoUniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map PatElem (TypeBase Shape NoUniqueness) -> VName
forall dec. PatElem dec -> VName
patElemName [PatElem (TypeBase Shape NoUniqueness)]
redomap_orig_res
then [Char]
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp, SubExp)
-> ((VName, VName, VName)
-> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName]
segMap3D [Char]
"rssss" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
se1, SubExp
ty, SubExp
tx) (((VName, VName, VName)
-> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName])
-> ((VName, VName, VName)
-> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
_ltid_z, VName
ltid_y, VName
ltid_x) ->
[(TypeBase Shape NoUniqueness, VName)]
-> ((TypeBase Shape NoUniqueness, VName)
-> BuilderT GPU (State VNameSource) SubExpRes)
-> BuilderT GPU (State VNameSource) [SubExpRes]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([TypeBase Shape NoUniqueness]
-> [VName] -> [(TypeBase Shape NoUniqueness, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [TypeBase Shape NoUniqueness]
kertp [VName]
redomap_res) (((TypeBase Shape NoUniqueness, VName)
-> BuilderT GPU (State VNameSource) SubExpRes)
-> BuilderT GPU (State VNameSource) [SubExpRes])
-> ((TypeBase Shape NoUniqueness, VName)
-> BuilderT GPU (State VNameSource) SubExpRes)
-> BuilderT GPU (State VNameSource) [SubExpRes]
forall a b. (a -> b) -> a -> b
$ \(TypeBase Shape NoUniqueness
res_tp, VName
res) -> do
VName
rss_init <- [Char]
-> PrimType -> [SubExp] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> PrimType -> [SubExp] -> m VName
scratch [Char]
"rss_init" (TypeBase Shape NoUniqueness -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType TypeBase Shape NoUniqueness
res_tp) [SubExp
rz, SubExp
se1, SubExp
se1]
(VName -> SubExpRes)
-> BuilderT GPU (State VNameSource) VName
-> BuilderT GPU (State VNameSource) SubExpRes
forall a b.
(a -> b)
-> BuilderT GPU (State VNameSource) a
-> BuilderT GPU (State VNameSource) b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> SubExpRes
varRes (BuilderT GPU (State VNameSource) VName
-> BuilderT GPU (State VNameSource) SubExpRes)
-> BuilderT GPU (State VNameSource) VName
-> BuilderT GPU (State VNameSource) SubExpRes
forall a b. (a -> b) -> a -> b
$
SubExp
-> [VName]
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forLoop SubExp
rz [VName
rss_init] ((VName -> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName)
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> BuilderT GPU (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ \VName
i [VName
rss] -> do
let slice :: Slice SubExp
slice = [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i, SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
se0, SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
se0]
VName
thread_res <- [Char]
-> VName -> [VName] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> m VName
index [Char]
"thread_res" VName
res [VName
ltid_y, VName
ltid_x, VName
i]
SubExp
rss' <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"rss" (Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp)
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update Safety
Unsafe VName
rss Slice SubExp
slice (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
thread_res
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [SubExp
rss']
else [Char]
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp, SubExp)
-> ((VName, VName, VName)
-> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName]
segMap3D [Char]
"rssss" SegLevel
segthd_lvl ResultManifest
ResultPrivate (SubExp
se1, SubExp
ty, SubExp
tx) (((VName, VName, VName)
-> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName])
-> ((VName, VName, VName)
-> BuilderT GPU (State VNameSource) [SubExpRes])
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
_ltid_z, VName
ltid_y, VName
ltid_x) -> do
[VName]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
gtid_y] (Exp GPU -> BuilderT GPU (State VNameSource) ())
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
jj1 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_y)
[VName]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
gtid_x] (Exp GPU -> BuilderT GPU (State VNameSource) ())
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
jj2 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid_x)
[VName]
rss_init <- [TypeBase Shape NoUniqueness]
-> (TypeBase Shape NoUniqueness
-> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [TypeBase Shape NoUniqueness]
kertp ((TypeBase Shape NoUniqueness
-> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName])
-> (TypeBase Shape NoUniqueness
-> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \TypeBase Shape NoUniqueness
res_tp ->
[Char]
-> PrimType -> [SubExp] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> PrimType -> [SubExp] -> m VName
scratch [Char]
"rss_init" (TypeBase Shape NoUniqueness -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType TypeBase Shape NoUniqueness
res_tp) [SubExp
rz, SubExp
se1, SubExp
se1]
[VName]
rss <- SubExp
-> [VName]
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName]
forLoop' SubExp
rz [VName]
rss_init ((VName -> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName])
-> (VName
-> [VName] -> BuilderT GPU (State VNameSource) (Body GPU))
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \VName
i [VName]
rss_merge -> do
[VName]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
gtid_z] (Exp GPU -> BuilderT GPU (State VNameSource) ())
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ii TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i)
[(PatElem (TypeBase Shape NoUniqueness), VName)]
-> ((PatElem (TypeBase Shape NoUniqueness), VName)
-> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem (TypeBase Shape NoUniqueness)]
-> [VName] -> [(PatElem (TypeBase Shape NoUniqueness), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem (TypeBase Shape NoUniqueness)]
redomap_orig_res [VName]
redomap_res) (((PatElem (TypeBase Shape NoUniqueness), VName)
-> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) ())
-> ((PatElem (TypeBase Shape NoUniqueness), VName)
-> BuilderT GPU (State VNameSource) VName)
-> BuilderT GPU (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(PatElem (TypeBase Shape NoUniqueness)
o_res, VName
n_res) -> do
VName
c <- [Char]
-> VName -> [VName] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> [VName] -> m VName
index [Char]
"redomap_thd" VName
n_res [VName
ltid_y, VName
ltid_x, VName
i]
[VName]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem (TypeBase Shape NoUniqueness) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (TypeBase Shape NoUniqueness)
o_res] (Exp GPU -> BuilderT GPU (State VNameSource) ())
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
c)
VName -> BuilderT GPU (State VNameSource) VName
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
c
[SubExp]
res_els <-
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) [SubExp]
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [SubExp]
letTupExp' [Char]
"res_elem"
(Exp GPU -> BuilderT GPU (State VNameSource) [SubExp])
-> BuilderT GPU (State VNameSource) (Exp GPU)
-> BuilderT GPU (State VNameSource) [SubExp]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
( TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource)))))
-> TPrimExp Bool VName
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$
VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_y
TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_Ky
TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_x
TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_Kx
TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
gtid_z
TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
d_M
)
( do
Stms (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms Stms (Rep (BuilderT GPU (State VNameSource)))
Stms GPU
code2'
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM ([SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource)))))
-> [SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
ker_res_nms
)
([BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource)))))
-> [BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (TypeBase Shape NoUniqueness
-> BuilderT GPU (State VNameSource) (Exp GPU))
-> [TypeBase Shape NoUniqueness]
-> [BuilderT GPU (State VNameSource) (Exp GPU)]
forall a b. (a -> b) -> [a] -> [b]
map TypeBase Shape NoUniqueness
-> BuilderT
GPU
(State VNameSource)
(Exp (Rep (BuilderT GPU (State VNameSource))))
TypeBase Shape NoUniqueness
-> BuilderT GPU (State VNameSource) (Exp GPU)
forall (m :: * -> *).
MonadBuilder m =>
TypeBase Shape NoUniqueness -> m (Exp (Rep m))
eBlank [TypeBase Shape NoUniqueness]
kertp)
[SubExp]
rss' <- [(SubExp, VName)]
-> ((SubExp, VName) -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([SubExp] -> [VName] -> [(SubExp, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
res_els [VName]
rss_merge) (((SubExp, VName) -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) [SubExp])
-> ((SubExp, VName) -> BuilderT GPU (State VNameSource) SubExp)
-> BuilderT GPU (State VNameSource) [SubExp]
forall a b. (a -> b) -> a -> b
$ \(SubExp
res_el, VName
rs_merge) -> do
let slice :: Slice SubExp
slice = [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i, SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
se0, SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
se0]
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"rss" (Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp)
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update Safety
Unsafe VName
rs_merge Slice SubExp
slice SubExp
res_el
[SubExp]
-> BuilderT
GPU
(State VNameSource)
(Body (Rep (BuilderT GPU (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [SubExp]
rss'
[SubExpRes] -> BuilderT GPU (State VNameSource) [SubExpRes]
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([SubExpRes] -> BuilderT GPU (State VNameSource) [SubExpRes])
-> [SubExpRes] -> BuilderT GPU (State VNameSource) [SubExpRes]
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExpRes]
varsRes [VName]
rss
let regtile_ret_dims :: [(SubExp, SubExp, SubExp)]
regtile_ret_dims =
((VName, SubExp) -> (SubExp, SubExp, SubExp))
-> [(VName, SubExp)] -> [(SubExp, SubExp, SubExp)]
forall a b. (a -> b) -> [a] -> [b]
map (\(VName
_, SubExp
sz) -> (SubExp
sz, SubExp
se1, SubExp
se1)) [(VName, SubExp)]
rem_outer_dims
[(SubExp, SubExp, SubExp)]
-> [(SubExp, SubExp, SubExp)] -> [(SubExp, SubExp, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(SubExp
d_M, SubExp
se1, SubExp
rz), (SubExp
d_Ky, SubExp
ty, SubExp
se1), (SubExp
d_Kx, SubExp
tx, SubExp
se1)]
[VName]
epilogue_res' <- [VName]
-> (VName -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
epilogue_res ((VName -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName])
-> (VName -> BuilderT GPU (State VNameSource) VName)
-> Builder GPU [VName]
forall a b. (a -> b) -> a -> b
$ \VName
res ->
if [(VName, SubExp)] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(VName, SubExp)]
rem_outer_dims
then VName -> BuilderT GPU (State VNameSource) VName
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
res
else do
TypeBase Shape NoUniqueness
res_tp' <- VName
-> BuilderT GPU (State VNameSource) (TypeBase Shape NoUniqueness)
forall rep (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
res
let ([SubExp]
block_dims, [SubExp]
rest_dims) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
2 ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$ TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims TypeBase Shape NoUniqueness
res_tp'
ones :: [SubExp]
ones = ((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> (VName, SubExp) -> SubExp
forall a b. a -> b -> a
const SubExp
se1) [(VName, SubExp)]
rem_outer_dims
new_shape :: Shape
new_shape = [SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$ [[SubExp]] -> [SubExp]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[SubExp]
ones, [SubExp]
block_dims, [SubExp]
ones, [SubExp]
rest_dims]
[Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"res_reshaped" (Exp GPU -> BuilderT GPU (State VNameSource) VName)
-> (BasicOp -> Exp GPU)
-> BasicOp
-> BuilderT GPU (State VNameSource) VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp GPU
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> BuilderT GPU (State VNameSource) VName)
-> BasicOp -> BuilderT GPU (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$
ReshapeKind -> Shape -> VName -> BasicOp
Reshape ReshapeKind
ReshapeArbitrary Shape
new_shape VName
res
[KernelResult] -> Builder GPU [KernelResult]
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([KernelResult] -> Builder GPU [KernelResult])
-> [KernelResult] -> Builder GPU [KernelResult]
forall a b. (a -> b) -> a -> b
$ (VName -> KernelResult) -> [VName] -> [KernelResult]
forall a b. (a -> b) -> [a] -> [b]
map (Certs -> [(SubExp, SubExp, SubExp)] -> VName -> KernelResult
RegTileReturns Certs
forall a. Monoid a => a
mempty [(SubExp, SubExp, SubExp)]
regtile_ret_dims) [VName]
epilogue_res'
let grid :: KernelGrid
grid = Count NumBlocks SubExp -> Count BlockSize SubExp -> KernelGrid
KernelGrid (SubExp -> Count NumBlocks SubExp
forall {k} (u :: k) e. e -> Count u e
Count SubExp
grid_size) (SubExp -> Count BlockSize SubExp
forall {k} (u :: k) e. e -> Count u e
Count SubExp
tblock_size)
level' :: SegLevel
level' = SegVirt -> Maybe KernelGrid -> SegLevel
SegBlock SegVirt
SegNoVirt (KernelGrid -> Maybe KernelGrid
forall a. a -> Maybe a
Just KernelGrid
grid)
space' :: SegSpace
space' = VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat ([(VName, SubExp)]
rem_outer_dims [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
gid_z, SubExp
gridDim_z), (VName
gid_y, SubExp
gridDim_y), (VName
gid_x, SubExp
gridDim_x)])
kbody' :: KernelBody GPU
kbody' = BodyDec GPU -> Stms GPU -> [KernelResult] -> KernelBody GPU
forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () Stms GPU
stms_seggroup [KernelResult]
ret_seggroup
Stm GPU -> Builder GPU (Stm GPU)
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stm GPU -> Builder GPU (Stm GPU))
-> Stm GPU -> Builder GPU (Stm GPU)
forall a b. (a -> b) -> a -> b
$ Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec GPU)
pat StmAux (ExpDec GPU)
aux (Exp GPU -> Stm GPU) -> Exp GPU -> Stm GPU
forall a b. (a -> b) -> a -> b
$ Op GPU -> Exp GPU
forall rep. Op rep -> Exp rep
Op (Op GPU -> Exp GPU) -> Op GPU -> Exp GPU
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPU -> HostOp SOAC GPU
forall (op :: * -> *) rep. SegOp SegLevel rep -> HostOp op rep
SegOp (SegOp SegLevel GPU -> HostOp SOAC GPU)
-> SegOp SegLevel GPU -> HostOp SOAC GPU
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [TypeBase Shape NoUniqueness]
-> KernelBody GPU
-> SegOp SegLevel GPU
forall lvl rep.
lvl
-> SegSpace
-> [TypeBase Shape NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegMap SegLevel
level' SegSpace
space' [TypeBase Shape NoUniqueness]
kertp KernelBody GPU
kbody'
Maybe (Stms GPU, Stm GPU) -> TileM (Maybe (Stms GPU, Stm GPU))
forall a. a -> ReaderT (Scope GPU) (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe (Stms GPU, Stm GPU) -> TileM (Maybe (Stms GPU, Stm GPU)))
-> Maybe (Stms GPU, Stm GPU) -> TileM (Maybe (Stms GPU, Stm GPU))
forall a b. (a -> b) -> a -> b
$ (Stms GPU, Stm GPU) -> Maybe (Stms GPU, Stm GPU)
forall a. a -> Maybe a
Just (Stms GPU
host_stms, Stm GPU
new_kernel)
where
-- Pull the variable name out of a kernel result.
--
-- Only a 'Returns' result tagged 'ResultMaySimplify' whose payload is a
-- variable ('Var') yields 'Just' that variable's name; every other kernel
-- result (other constructors, other manifests, constant payloads) gives
-- 'Nothing'.
getResNm :: KernelResult -> Maybe VName
getResNm kernel_res =
  case kernel_res of
    Returns ResultMaySimplify _ (Var nm) -> Just nm
    _ -> Nothing
-- Build the signed 64-bit minimum ('SMin' 'Int64') of the two given
-- sub-expressions and pass it to 'letSubExp' together with the supplied
-- name string; the resulting 'SubExp' is returned in the builder monad.
--
-- NOTE(review): judging by the call-site naming, the first 'SubExp' is a
-- tile size and the second the extent of the dimension it tiles, so this
-- presumably clamps the tile to the dimension -- confirm against callers.
limitTile :: String -> SubExp -> SubExp -> Builder GPU SubExp
limitTile name_hint tile_sz dim_sz =
  letSubExp name_hint (BasicOp clamped)
  where
    clamped = BinOp (SMin Int64) tile_sz dim_sz
-- Classify one array-indexing statement into one of two tables, both keyed
-- by the name the statement binds.
--
-- The statement must be a 'Let' over a 'BasicOp' 'Index' with exactly one
-- pattern element, and that element's name must equal the name paired with
-- the statement; any other input hits the final equation and calls 'error'.
--
-- The slice is scanned with 'variantSliceDim' for dimensions that are fixed
-- to a variable which 'variantToDim' reports as variant to the given
-- grid-index name:
--
-- * No such dimension: the statement is stored unchanged in the first table
--   and the second table is returned as-is.
--
-- * Otherwise, with @d@ the first such dimension: the permutation
--   @[d+1 .. rank-1] ++ [0 .. d]@ is built from the rank of the indexed
--   array, a 'Manifest' of the array under that permutation is bound (its
--   name hint is the array's base string followed by "_transp"), and a copy
--   of the statement that indexes the manifested array with the same slice
--   is stored in the second table, paired with the element type of the
--   pattern element. The first table is returned as-is.
--
-- The size component of the @(VName, SubExp)@ argument is ignored.
insertTranspose ::
  VarianceTable ->
  (VName, SubExp) ->
  (M.Map VName (Stm GPU), M.Map VName (PrimType, Stm GPU)) ->
  (VName, Stm GPU) ->
  Builder GPU (M.Map VName (Stm GPU), M.Map VName (PrimType, Stm GPU))
insertTranspose
  var_tab
  (grid_ix, _)
  (inner_tab, outer_tab)
  (target_nm, idx_stm@(Let idx_pat idx_aux (BasicOp (Index src_arr idx_slc))))
    | [pat_el] <- patElems idx_pat,
      target_nm == patElemName pat_el =
        case L.findIndices (variantSliceDim var_tab grid_ix) (unSlice idx_slc) of
          [] ->
            -- Nothing in the slice varies with the grid index: keep as is.
            pure (M.insert target_nm idx_stm inner_tab, outer_tab)
          var_dim : _ -> do
            src_rank <- arrayRank <$> lookupType src_arr
            -- Rotate the dimensions so that everything after the variant
            -- dimension comes first, followed by dimensions 0 .. var_dim.
            let rot_perm = [var_dim + 1 .. src_rank - 1] ++ [0 .. var_dim]
            transp_arr <-
              letExp (baseString src_arr ++ "_transp") $
                BasicOp (Manifest rot_perm src_arr)
            let reindexed = Let idx_pat idx_aux (BasicOp (Index transp_arr idx_slc))
                elem_pt = elemType (patElemType pat_el)
            pure (inner_tab, M.insert target_nm (elem_pt, reindexed) outer_tab)
insertTranspose _ _ _ _ =
  error "\nUnreachable case reached in insertTranspose case, doRegTiling3D\n"
-- Decide whether a single slice dimension is variant to the given name.
--
-- Only a dimension fixed to a variable (@DimFix (Var v)@) can qualify, in
-- which case the answer is whatever 'variantToDim' reports for that
-- variable; every other kind of dimension index yields 'False'.
variantSliceDim :: VarianceTable -> VName -> DimIndex SubExp -> Bool
variantSliceDim var_tab grid_ix dim_ix =
  case dim_ix of
    DimFix (Var ix_var) -> variantToDim var_tab grid_ix ix_var
    _ -> False
-- Catch-all equation: a statement that no earlier equation of
-- 'doRegTiling3D' accepted produces no result ('Nothing').
doRegTiling3D _unmatched_stm = pure Nothing