module Futhark.Optimise.TileLoops.Shared
( TileM,
Env,
index,
update,
forLoop',
forLoop,
segMap1D,
segMap2D,
segMap3D,
segScatter2D,
VarianceTable,
varianceInStms,
isTileableRedomap,
changeEnv,
TileKind (..),
)
where
import Control.Monad.Reader
import Control.Monad.State
import Data.List (foldl', zip4)
import Data.Map qualified as M
import Futhark.IR.GPU
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.SeqMem qualified as ExpMem
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Transform.Rename
-- | Monad in which this module's environment helpers (e.g. 'changeEnv') run:
-- read-only access to the GPU type scope, plus a 'VNameSource' in state for
-- generating fresh names.
type TileM = ReaderT (Scope GPU) (State VNameSource)
-- | Distinguishes partial tiles from full tiles.  No behaviour is attached to
-- the distinction in this module.
-- NOTE(review): presumably partial tiles are those that may overhang the
-- iteration space and need guarding; confirm against the users of this type.
data TileKind = TilePartial | TileFull
-- | @index desc arr is@ binds, under the name hint @desc@, the result of
-- indexing @arr@ with the variables @is@ in its outermost dimensions.  Every
-- remaining inner dimension of size @d@ is kept whole via the slice @0:d:1@.
--
-- Assumes @length is@ does not exceed the rank of @arr@.
index :: MonadBuilder m => String -> VName -> [VName] -> m VName
index se_desc arr outer_indices = do
  arr_t <- lookupType arr
  let shape = arrayShape arr_t
      inner_dims = shapeDims $ stripDims (length outer_indices) shape
      untouched d = DimSlice (intConst Int64 0) d (intConst Int64 1)
      inner_slices = map untouched inner_dims
      slice = Slice $ map (DimFix . Var) outer_indices ++ inner_slices
  letExp se_desc $ BasicOp $ Index arr slice
-- | @update desc arr is v@ binds, under the name hint @desc@, an update of
-- @arr@ that writes @v@ at the position given by the index variables @is@.
-- The update is marked 'Unsafe', i.e. it is not dynamically bounds-checked.
update :: MonadBuilder m => String -> VName -> [VName] -> SubExp -> m VName
update se_desc arr indices new_elem =
  letExp se_desc . BasicOp $
    Update Unsafe arr (Slice $ map (DimFix . Var) indices) new_elem
-- | @forLoop' bound merge body@ binds a sequential loop whose index @i@
-- (of type @i64@) ranges up to @bound@.  The loop-carried values are
-- initialised from @merge@.
--
-- The @body@ callback receives the loop index and the names of the fresh
-- merge parameters, and must build the loop body.  It runs in a scope that
-- contains both the loop index and the merge parameters.
--
-- Returns the names bound to all loop results, in the order of @merge@.
forLoop' ::
  SubExp ->
  [VName] ->
  (VName -> [VName] -> Builder GPU (Body GPU)) ->
  Builder GPU [VName]
forLoop' i_bound merge body = do
  i <- newVName "i"
  let loop_form = ForLoop i Int64 i_bound []

  merge_ts <- mapM lookupType merge
  -- The merge parameters are declared 'Unique', presumably so that the body
  -- may consume (update in place) loop-carried arrays.
  loop_inits <- mapM (\merge_t -> newParam "merge" $ toDecl merge_t Unique) merge_ts

  loop_body <-
    runBodyBuilder . inScopeOf loop_form . localScope (scopeOfFParams loop_inits) $
      body i $ map paramName loop_inits

  letTupExp "loop" $
    DoLoop (zip loop_inits $ map Var merge) loop_form loop_body
-- | Like 'forLoop'', but returns only the name bound to the first loop
-- result.  It is intended for loops with a single loop-carried value.
--
-- Calls 'error' with a descriptive message if @merge@ is empty, because the
-- loop then has no results.  As with the former use of 'head', the error
-- fires only when the returned name is demanded.
forLoop ::
  SubExp ->
  [VName] ->
  (VName -> [VName] -> Builder GPU (Body GPU)) ->
  Builder GPU VName
forLoop i_bound merge body = do
  res_list <- forLoop' i_bound merge body
  pure $ case res_list of
    res : _ -> res
    [] -> error "forLoop: the loop has no results (empty merge list)."
-- | @segMap1D desc lvl manifest f@ binds a one-dimensional 'SegMap' at level
-- @lvl@.  Its single dimension has the group size of @lvl@ as extent.
--
-- The callback @f@ receives the local thread index and builds the
-- per-thread results.  Every result is returned as @Returns manifest@.
-- The kernel body is renamed before use.
segMap1D ::
  String ->
  SegLevel ->
  ResultManifest ->
  (VName -> Builder GPU Result) ->
  Builder GPU [VName]
segMap1D desc lvl manifest f = do
  ltid <- newVName "ltid"
  ltid_flat <- newVName "ltid_flat"
  let space = SegSpace ltid_flat [(ltid, unCount $ segGroupSize lvl)]

  ((ts, res), stms) <- localScope (scopeOfSegSpace space) . runBuilder $ do
    res <- f ltid
    ts <- mapM subExpResType res
    pure (ts, res)
  Body _ stms' res' <- renameBody $ mkBody stms res

  let ret (SubExpRes cs se) = Returns manifest cs se
  letTupExp desc $
    Op . SegOp $
      SegMap lvl space ts $
        KernelBody () stms' $
          map ret res'
-- | @segMap2D desc lvl manifest (dim_y, dim_x) f@ binds a two-dimensional
-- 'SegMap' at level @lvl@ over the space @[dim_y][dim_x]@, with y as the
-- outer dimension.
--
-- The callback @f@ receives the local thread indices as @(ltid_y, ltid_x)@.
-- Every result is returned as @Returns manifest@.  The whole 'SegMap'
-- expression is renamed before being bound.
segMap2D ::
  String ->
  SegLevel ->
  ResultManifest ->
  (SubExp, SubExp) ->
  ((VName, VName) -> Builder GPU Result) ->
  Builder GPU [VName]
segMap2D desc lvl manifest (dim_y, dim_x) f = do
  ltid_xx <- newVName "ltid_x"
  ltid_yy <- newVName "ltid_y"
  ltid_flat <- newVName "ltid_flat"
  let segspace = SegSpace ltid_flat [(ltid_yy, dim_y), (ltid_xx, dim_x)]

  ((ts, res), stms) <- localScope (scopeOfSegSpace segspace) . runBuilder $ do
    res <- f (ltid_yy, ltid_xx)
    ts <- mapM subExpResType res
    pure (ts, res)

  let ret (SubExpRes cs se) = Returns manifest cs se
  letTupExp desc <=< renameExp $
    Op . SegOp $
      SegMap lvl segspace ts $
        KernelBody () stms $
          map ret res
-- | @segMap3D desc lvl manifest (dim_z, dim_y, dim_x) f@ binds a
-- three-dimensional 'SegMap' at level @lvl@ over the space
-- @[dim_z][dim_y][dim_x]@, with z as the outermost dimension.
--
-- The callback @f@ receives the local thread indices as
-- @(ltid_z, ltid_y, ltid_x)@.  Every result is returned as
-- @Returns manifest@.  The whole 'SegMap' expression is renamed before being
-- bound.
segMap3D ::
  String ->
  SegLevel ->
  ResultManifest ->
  (SubExp, SubExp, SubExp) ->
  ((VName, VName, VName) -> Builder GPU Result) ->
  Builder GPU [VName]
segMap3D desc lvl manifest (dim_z, dim_y, dim_x) f = do
  ltid_flat <- newVName "ltid_flat"
  ltid_z <- newVName "ltid_z"
  ltid_y <- newVName "ltid_y"
  ltid_x <- newVName "ltid_x"
  let segspace =
        SegSpace ltid_flat [(ltid_z, dim_z), (ltid_y, dim_y), (ltid_x, dim_x)]

  ((ts, res), stms) <- localScope (scopeOfSegSpace segspace) . runBuilder $ do
    res <- f (ltid_z, ltid_y, ltid_x)
    ts <- mapM subExpResType res
    pure (ts, res)

  let ret (SubExpRes cs se) = Returns manifest cs se
  letTupExp desc <=< renameExp $
    Op . SegOp $
      SegMap lvl segspace ts $
        KernelBody () stms $
          map ret res
-- | @segScatter2D desc arr_size updt_arr lvl seq_dims (dim_x, dim_y) f@
-- binds a 'SegMap' that scatters into @updt_arr@, whose write shape is
-- @[arr_size]@.
--
-- The segment space is @seq_dims@ followed by @(ltid_y, dim_y)@ and
-- @(ltid_x, dim_x)@.  The @seq_dims@ dimensions are listed in 'SegSeqDims'
-- within a 'SegNoVirtFull' thread level that otherwise copies the group
-- count and group size of @lvl@.
--
-- The callback @f@ receives the indices of the @seq_dims@ dimensions and
-- @(ltid_y, ltid_x)@.  It returns @(value, index)@: each thread writes
-- @value@ to @updt_arr@ at position @index@.
--
-- The callback and the result-type lookup both run with the segment-space
-- indices in scope.  This lets @f@ return one of them directly, and matches
-- 'segMap2D'.
--
-- NOTE(review): the dimension pair is taken as @(dim_x, dim_y)@, the
-- opposite order from 'segMap2D'.  Callers must follow this order.
segScatter2D ::
  String ->
  SubExp ->
  VName ->
  SegLevel ->
  [SubExp] ->
  (SubExp, SubExp) ->
  ([VName] -> (VName, VName) -> Builder GPU (SubExp, SubExp)) ->
  Builder GPU VName
segScatter2D desc arr_size updt_arr lvl seq_dims (dim_x, dim_y) f = do
  ltid_flat <- newVName "ltid_flat"
  ltid_y <- newVName "ltid_y"
  ltid_x <- newVName "ltid_x"

  seq_is <- replicateM (length seq_dims) (newVName "ltid_seq")
  let seq_space = zip seq_is seq_dims
      segspace =
        SegSpace ltid_flat $ seq_space ++ [(ltid_y, dim_y), (ltid_x, dim_x)]
      -- The leading (sequential) dimensions are flagged by position.
      lvl' =
        SegThread
          (segNumGroups lvl)
          (segGroupSize lvl)
          (SegNoVirtFull (SegSeqDims [0 .. length seq_dims - 1]))

  ((t_v, res_v, res_i), stms) <- localScope (scopeOfSegSpace segspace) . runBuilder $ do
    (res_v, res_i) <- f seq_is (ltid_y, ltid_x)
    t_v <- subExpType res_v
    pure (t_v, res_v, res_i)

  let ret = WriteReturns mempty (Shape [arr_size]) updt_arr [(Slice [DimFix res_i], res_v)]
      body = KernelBody () stms [ret]

  letExp desc <=< renameExp $ Op $ SegOp $ SegMap lvl' segspace [t_v] body
-- | Maps a variable to the set of names it depends on.  'varianceInStms'
-- processes statements in order, and each new entry includes the entries of
-- the variables it uses, so the sets accumulate dependencies transitively.
type VarianceTable = M.Map VName Names
-- | Recognise a 'Screma' statement that is a redomap suitable for tiling.
--
-- On success it returns the width, the input arrays, and a tuple of:
-- the commutativity, reduction lambda, and neutral elements of the
-- reductions combined with 'singleReduce', plus the map lambda.
--
-- The following must hold:
--
-- * every parameter of both lambdas has a primitive row type;
-- * the map lambda returns exactly the reduction's types (no map-out arrays);
-- * there is at least one input array;
-- * the map lambda returns only primitive values and takes only primitive
--   parameters.
isTileableRedomap ::
  Stm GPU ->
  Maybe
    ( SubExp,
      [VName],
      (Commutativity, Lambda GPU, [SubExp], Lambda GPU)
    )
isTileableRedomap stm
  | Op (OtherOp (Screma w arrs form)) <- stmExp stm,
    Just (reds, map_lam) <- isRedomapSOAC form,
    Reduce red_comm red_lam red_nes <- singleReduce reds,
    all (primType . rowType . paramType) $ lambdaParams red_lam,
    all (primType . rowType . paramType) $ lambdaParams map_lam,
    lambdaReturnType map_lam == lambdaReturnType red_lam, -- No mapout arrays.
    not (null arrs),
    all primType $ lambdaReturnType map_lam,
    all (primType . paramType) $ lambdaParams map_lam =
      Just (w, arrs, (red_comm, red_lam, red_nes, map_lam))
  | otherwise =
      Nothing
-- | Extend the table with an entry for every name bound by the statement's
-- pattern.
--
-- Every such name gets the same set: the union, over each free variable @v@
-- of the whole statement, of @v@ itself and whatever @v@ already maps to in
-- the table.  Existing entries for the bound names are overwritten.
defVarianceInStm :: VarianceTable -> Stm GPU -> VarianceTable
defVarianceInStm variance stm =
  foldl' add variance $ patNames $ stmPat stm
  where
    add variance' v = M.insert v binding_variance variance'
    look variance' v = oneName v <> M.findWithDefault mempty v variance'
    binding_variance = mconcat $ map (look variance) $ namesToList (freeIn stm)
-- | Extend the variance table with the names bound by one statement.
--
-- For a tileable redomap (see 'isTileableRedomap'), the lambda parameters
-- are added first, then the statements of both lambda bodies:
--
-- * each map parameter maps to its input array plus that array's entry;
-- * the reduction parameters are paired positionally with the input arrays
--   and map parameters.  Each element parameter additionally depends on its
--   paired map parameter.  Each accumulator parameter additionally depends
--   on its element parameter.
--
-- Any other statement is handled by 'defVarianceInStm'.
varianceInStm :: VarianceTable -> Stm GPU -> VarianceTable
varianceInStm v0 stm@(Let _ _ (Op (OtherOp Screma {})))
  | Just (_, arrs, (_, red_lam, red_nes, map_lam)) <- isTileableRedomap stm =
      let v = defVarianceInStm v0 stm
          red_ps = lambdaParams red_lam
          map_ps = lambdaParams map_lam
          card_red = length red_nes
          -- The reduction lambda takes 'card_red' accumulator parameters
          -- followed by 'card_red' element parameters.  (Splitting at
          -- @card_red `quot` 2@ instead would give no accumulator parameters
          -- for a single reduction, so no reduction parameter would be
          -- recorded.)
          (acc_lam_f, arr_lam_f) = splitAt card_red red_ps
          stm_lam = bodyStms (lambdaBody map_lam) <> bodyStms (lambdaBody red_lam)

          arrVariance vacc v_a = oneName v_a <> M.findWithDefault mempty v_a vacc

          -- Every map parameter varies like the array it is drawn from.
          -- This is done for all arrays, because the 'zip4' below stops at
          -- the shortest list (e.g. two arrays but a single reduction).
          bindMapParam vacc (v_a, v_fm) = M.insert v_fm (arrVariance vacc v_a) vacc

          f vacc (v_a, v_fm, v_fr_acc, v_fr_var) =
            let vrc = arrVariance vacc v_a
                vacc' = M.insert v_fm vrc vacc
                vrc' = oneName v_fm <> vrc
             in M.insert v_fr_acc (oneName v_fr_var <> vrc') $
                  M.insert v_fr_var vrc' vacc'

          v_map = foldl' bindMapParam v $ zip arrs $ map paramName map_ps
          v' =
            foldl' f v_map $
              zip4
                arrs
                (map paramName map_ps)
                (map paramName acc_lam_f)
                (map paramName arr_lam_f)
       in varianceInStms v' stm_lam
varianceInStm v0 stm = defVarianceInStm v0 stm
-- | Extend a variance table with a sequence of statements, processed in
-- order, so that each statement sees the entries added for earlier ones.
varianceInStms :: VarianceTable -> Stms GPU -> VarianceTable
varianceInStms = foldl' varianceInStm
-- | Index functions over 64-bit integer expressions in terms of variables.
type IxFun = IxFun.IxFun (TPrimExp Int64 VName)
-- | Maps an array name to an index function describing its layout.
-- 'changeIxFnEnv' builds entries from reshape, rearrange, index, opaque and
-- manifest operations.
type IxFnEnv = M.Map VName IxFun
-- | Maps leading parameters of a @WithAcc@ lambda to the accumulator's
-- operator (with its index parameters dropped) and its neutral elements.
-- See 'changeWithEnv'.
type WithEnv = M.Map VName (Lambda GPU, [SubExp])
-- | The pair of environments updated together by 'changeEnv'.
type Env = (WithEnv, IxFnEnv)
-- | Update both environments for the binding @y = e@.  The @WithAcc@
-- environment is updated first, then the index-function environment.
changeEnv :: Env -> VName -> Exp GPU -> TileM Env
changeEnv (with_env, ixfn_env) y e =
  (,)
    <$> changeWithEnv with_env e
    <*> changeIxFnEnv ixfn_env y e
-- | For a @WithAcc@ expression, record each accumulator input's operator and
-- neutral elements.
--
-- Each record is keyed by the corresponding one of the first
-- @length inputs@ parameters of the @WithAcc@ lambda.  The operator's
-- leading index parameters (one per dimension of the input's shape) are
-- dropped.  Existing entries take precedence ('M.union' is left-biased).
-- Any other expression leaves the environment unchanged.
--
-- Calls 'error' on an accumulator input without an operator.
--
-- NOTE(review): confirm which of the @WithAcc@ lambda's parameter groups the
-- leading parameters are, and that lookups use the same names.
changeWithEnv :: WithEnv -> Exp GPU -> TileM WithEnv
changeWithEnv with_env (WithAcc accum_decs inner_lam) =
  pure $ M.union with_env $ M.fromList $ zip par_tps bindings
  where
    bindings = map mapfun accum_decs
    par_tps = take (length bindings) $ map paramName $ lambdaParams inner_lam

    mapfun (_, _, Nothing) = error "What the hack is an accumulator without operator?"
    mapfun (shp, _, Just (lam_inds, ne)) =
      let len_inds = length $ shapeDims shp
          lam_op = lam_inds {lambdaParams = drop len_inds $ lambdaParams lam_inds}
       in (lam_op, ne)
changeWithEnv with_env _ = pure with_env
-- | @composeIxfuns env y x g@ records for @y@ the index function @g ixf@,
-- where @ixf@ is the index function recorded for @x@.
--
-- If @x@ has no entry but is an array, @ixf@ is 'IxFun.iota' over @x@'s
-- shape.  If @x@ has no entry and is not an array, @env@ is returned
-- unchanged.
composeIxfuns :: IxFnEnv -> VName -> VName -> (IxFun -> IxFun) -> TileM IxFnEnv
composeIxfuns env y x ixf_fun =
  case M.lookup x env of
    Just ixf -> pure $ M.insert y (ixf_fun ixf) env
    Nothing -> do
      tp <- lookupType x
      case tp of
        Array _ptp shp _u -> do
          let shp' = map ExpMem.pe64 (shapeDims shp)
          pure $ M.insert y (ixf_fun $ IxFun.iota shp') env
        _ -> pure env
-- | Update the index-function environment for the binding @y = e@.
--
-- These forms derive @y@'s index function from @x@'s, via 'composeIxfuns':
--
-- * an arbitrary 'Reshape' ('IxFun.reshape');
-- * a coercing 'Reshape' ('IxFun.coerce');
-- * 'Rearrange' ('IxFun.permute');
-- * 'Index' ('IxFun.slice');
-- * 'Opaque' of a variable (unchanged).
--
-- A 'Manifest' instead records 'IxFun.permute' of 'IxFun.iota' over the
-- source's shape.  It calls 'error' if the source is not an array.
--
-- All other expressions leave the environment unchanged.
changeIxFnEnv :: IxFnEnv -> VName -> Exp GPU -> TileM IxFnEnv
changeIxFnEnv env y (BasicOp (Reshape ReshapeArbitrary shp_chg x)) =
  composeIxfuns env y x (`IxFun.reshape` fmap ExpMem.pe64 (shapeDims shp_chg))
changeIxFnEnv env y (BasicOp (Reshape ReshapeCoerce shp_chg x)) =
  composeIxfuns env y x (`IxFun.coerce` fmap ExpMem.pe64 (shapeDims shp_chg))
changeIxFnEnv env y (BasicOp (Manifest perm x)) = do
  tp <- lookupType x
  case tp of
    Array _ptp shp _u -> do
      let shp' = map ExpMem.pe64 (shapeDims shp)
          ixfn = IxFun.permute (IxFun.iota shp') perm
      pure $ M.insert y ixfn env
    _ -> error "In TileLoops/Shared.hs, changeIxFnEnv: manifest applied to a non-array!"
changeIxFnEnv env y (BasicOp (Rearrange perm x)) =
  composeIxfuns env y x (`IxFun.permute` perm)
changeIxFnEnv env y (BasicOp (Index x slc)) =
  composeIxfuns env y x (`IxFun.slice` Slice (map (fmap ExpMem.pe64) $ unSlice slc))
changeIxFnEnv env y (BasicOp (Opaque _ (Var x))) =
  composeIxfuns env y x id
changeIxFnEnv env _ _ = pure env