module Futhark.Optimise.TileLoops.Shared
  ( TileM,
    Env,
    index,
    update,
    forLoop',
    forLoop,
    segMap1D,
    segMap2D,
    segMap3D,
    segScatter2D,
    VarianceTable,
    varianceInStms,
    isTileableRedomap,
    changeEnv,
    TileKind (..),
  )
where

import Control.Monad.Reader
import Control.Monad.State
import Data.List (foldl', zip4)
import Data.Map qualified as M
import Futhark.IR.GPU
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.SeqMem qualified as ExpMem
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Transform.Rename

-- | The monad in which tiling is performed: read-only access to the GPU
-- type environment ('Scope'), layered over a state monad that threads the
-- 'VNameSource' used to generate fresh names.
type TileM = ReaderT (Scope GPU) (State VNameSource)

-- | Are we working with full or partial tiles?
data TileKind
  = TilePartial
  | TileFull

-- | @index se_desc arr outer_indices@ indexes @arr@ with @outer_indices@
-- in its leading dimensions.  Every remaining (inner) dimension of @arr@
-- that is not covered by @outer_indices@ is sliced in its entirety
-- (start 0, length = the dimension, stride 1), so it is kept in the result.
-- The resulting 'Index' expression is bound with 'letExp' using @se_desc@
-- as the name description.
index :: MonadBuilder m => String -> VName -> [VName] -> m VName
index se_desc arr outer_indices = do
  arr_shape <- arrayShape <$> lookupType arr
  let fixed_part = [DimFix (Var i) | i <- outer_indices]
      remaining = shapeDims $ stripDims (length outer_indices) arr_shape
      whole d = DimSlice (intConst Int64 0) d (intConst Int64 1)
  letExp se_desc . BasicOp . Index arr . Slice $
    fixed_part ++ map whole remaining

-- | @update se_desc arr indices new_elem@ writes @new_elem@ into @arr@ at
-- the position given by @indices@ (one 'DimFix' per index).  The 'Update'
-- is emitted with 'Unsafe' safety and bound with 'letExp' using @se_desc@
-- as the name description.
update :: MonadBuilder m => String -> VName -> [VName] -> SubExp -> m VName
update se_desc arr indices new_elem =
  letExp se_desc (BasicOp write)
  where
    position = Slice [DimFix (Var i) | i <- indices]
    write = Update Unsafe arr position new_elem

-- | @forLoop' i_bound merge body@ builds a sequential 'DoLoop' with a
-- 'ForLoop' form whose 'Int64' index is bounded by @i_bound@.
--
-- For every name in @merge@, a fresh loop parameter named @"merge"@ is
-- created.  It has the type of that name, marked 'Unique'.  The loop is
-- initialised from the @merge@ names.  @body@ is called with the loop index
-- and the loop parameter names, inside a scope that contains both.  It must
-- build the loop body.  Returns the names bound to the loop's results.
forLoop' ::
  SubExp -> -- loop bound
  [VName] -> -- loop inits
  ( VName ->
    [VName] -> -- (loop var -> loop params -> loop body)
    Builder GPU (Body GPU)
  ) ->
  Builder GPU [VName]
forLoop' i_bound merge body = do
  loop_i <- newVName "i" -- could give this as arg to the function
  let form = ForLoop loop_i Int64 i_bound []

  init_ts <- traverse lookupType merge
  params <- traverse (\t -> newParam "merge" (toDecl t Unique)) init_ts

  loop_body <-
    runBodyBuilder $
      inScopeOf form $
        localScope (scopeOfFParams params) $
          body loop_i (map paramName params)

  letTupExp "loop" $ DoLoop (zip params (map Var merge)) form loop_body

-- | Like 'forLoop'', but returns only the name bound to the loop's first
-- result.  This is presumably meant for loops with a single loop-carried
-- variable; any further results are ignored.
--
-- Calls 'error' (lazily, only when the returned name is demanded) if the
-- loop has no results at all, e.g. because @merge@ was empty.
forLoop ::
  SubExp ->
  [VName] ->
  (VName -> [VName] -> Builder GPU (Body GPU)) ->
  Builder GPU VName
forLoop i_bound merge body = do
  res_list <- forLoop' i_bound merge body
  -- Match explicitly instead of using the partial 'head', so an empty result
  -- list fails with a descriptive message.  The match stays inside 'pure' so
  -- the failure is exactly as lazy as the previous 'head'-based version.
  pure $ case res_list of
    res : _ -> res
    [] -> error "forLoop: loop produced no results (empty merge list?)"

-- | Build a one-dimensional 'SegMap' at level @lvl@.
--
-- The single dimension has the group size of @lvl@.  @f@ is called with
-- the thread index and returns the per-thread results.  Each result is
-- returned with the given @manifest@.  The statements produced by @f@ are
-- renamed (with 'renameBody') before being placed in the kernel body.
-- Returns the names bound to the 'SegMap''s results.
segMap1D ::
  String ->
  SegLevel ->
  ResultManifest ->
  (VName -> Builder GPU Result) ->
  Builder GPU [VName]
segMap1D desc lvl manifest f = do
  ltid <- newVName "ltid"
  ltid_flat <- newVName "ltid_flat"
  let group_size = unCount $ segGroupSize lvl
      space = SegSpace ltid_flat [(ltid, group_size)]
      asKernelResult (SubExpRes cs se) = Returns manifest cs se
      perThread = do
        thread_res <- f ltid
        thread_ts <- mapM subExpResType thread_res
        pure (thread_ts, thread_res)

  ((ts, res), stms) <- localScope (scopeOfSegSpace space) $ runBuilder perThread
  -- Rename the names bound inside the body before using it in the kernel.
  Body _ stms' res' <- renameBody $ mkBody stms res

  letTupExp desc . Op . SegOp . SegMap lvl space ts $
    KernelBody () stms' (map asKernelResult res')

-- | Build a two-dimensional 'SegMap' at level @lvl@.
--
-- The dimension pair is given as @(dim_y, dim_x)@, in the order the
-- dimensions are listed in the 'SegSpace' (@dim_y@ first).  @f@ is called
-- with the matching thread indices @(ltid_y, ltid_x)@ and returns the
-- per-thread results.  Each result is returned with the given @manifest@.
-- The whole 'SegMap' expression is renamed with 'renameExp' before being
-- bound.  Returns the names bound to the 'SegMap''s results.
segMap2D ::
  String -> -- desc
  SegLevel -> -- lvl
  ResultManifest -> -- manifest
  (SubExp, SubExp) -> -- (dim_y, dim_x)
  ( (VName, VName) -> -- f, called with (ltid_y, ltid_x)
    Builder GPU Result
  ) ->
  Builder GPU [VName]
segMap2D desc lvl manifest (dim_y, dim_x) f = do
  ltid_xx <- newVName "ltid_x"
  ltid_yy <- newVName "ltid_y"
  ltid_flat <- newVName "ltid_flat"
  let segspace = SegSpace ltid_flat [(ltid_yy, dim_y), (ltid_xx, dim_x)]

  ((ts, res), stms) <- localScope (scopeOfSegSpace segspace) . runBuilder $ do
    res <- f (ltid_yy, ltid_xx)
    ts <- mapM subExpResType res
    pure (ts, res)

  let ret (SubExpRes cs se) = Returns manifest cs se
  letTupExp desc <=< renameExp $
    Op . SegOp $
      SegMap lvl segspace ts $
        KernelBody () stms $
          map ret res

-- | Build a three-dimensional 'SegMap' at level @lvl@.
--
-- The dimensions are given as @(dim_z, dim_y, dim_x)@, in the order they
-- are listed in the 'SegSpace'.  @f@ is called with the matching thread
-- indices @(ltid_z, ltid_y, ltid_x)@ and returns the per-thread results.
-- Each result is returned with the given @manifest@.  The whole 'SegMap'
-- expression is renamed with 'renameExp' before being bound.  Returns the
-- names bound to the 'SegMap''s results.
segMap3D ::
  String -> -- desc
  SegLevel -> -- lvl
  ResultManifest -> -- manifest
  (SubExp, SubExp, SubExp) -> -- (dim_z, dim_y, dim_x)
  ( (VName, VName, VName) -> -- f
    Builder GPU Result
  ) ->
  Builder GPU [VName]
segMap3D desc lvl manifest (dim_z, dim_y, dim_x) f = do
  flat_id <- newVName "ltid_flat"
  tid_z <- newVName "ltid_z"
  tid_y <- newVName "ltid_y"
  tid_x <- newVName "ltid_x"
  let space = SegSpace flat_id [(tid_z, dim_z), (tid_y, dim_y), (tid_x, dim_x)]
      toKernelRes (SubExpRes cs se) = Returns manifest cs se

  ((ret_ts, body_res), body_stms) <-
    localScope (scopeOfSegSpace space) $
      runBuilder $ do
        thread_res <- f (tid_z, tid_y, tid_x)
        thread_ts <- mapM subExpResType thread_res
        pure (thread_ts, thread_res)

  let kbody = KernelBody () body_stms $ map toKernelRes body_res
  segmap <- renameExp $ Op $ SegOp $ SegMap lvl space ret_ts kbody
  letTupExp desc segmap

-- | Build a 'SegMap' in which every thread writes one value into the array
-- @updt_arr@, whose write shape is @[arr_size]@ (via 'WriteReturns').
--
-- The segment space consists of one fresh index per entry of @seq_dims@,
-- followed by a two-dimensional thread space.  The level is a 'SegThread'
-- with the number of groups and the group size of @lvl@.  It is marked
-- 'SegNoVirtFull', with the @seq_dims@ positions as its sequential
-- dimensions.
--
-- @f@ receives the sequential indices and the thread indices
-- @(ltid_y, ltid_x)@.  It returns @(value, index)@, and @value@ is written
-- at position @index@ of @updt_arr@.  The 'SegMap' is renamed with
-- 'renameExp' before being bound.  Returns the name bound to its result.
--
-- NOTE: the dimension pair is bound as @(dim_x, dim_y)@.  Its FIRST
-- component is the dimension ranged over by @ltid_x@, which is listed last
-- (innermost) in the 'SegSpace'.  Its second component is ranged over by
-- @ltid_y@.  This is the opposite order from 'segMap2D', which takes
-- @(dim_y, dim_x)@.
-- NOTE(review): confirm with the callers whether this swap is intentional;
-- the behaviour is kept unchanged here.
segScatter2D ::
  String -> -- desc
  SubExp -> -- arr_size
  VName -> -- updt_arr
  SegLevel -> -- lvl
  [SubExp] -> -- dims of sequential loop on top
  (SubExp, SubExp) -> -- (dim_x, dim_y): x (inner) first, unlike segMap2D
  ([VName] -> (VName, VName) -> Builder GPU (SubExp, SubExp)) -> -- f, returns (value, index)
  Builder GPU VName
segScatter2D desc arr_size updt_arr lvl seq_dims (dim_x, dim_y) f = do
  ltid_flat <- newVName "ltid_flat"
  ltid_y <- newVName "ltid_y"
  ltid_x <- newVName "ltid_x"

  seq_is <- replicateM (length seq_dims) (newVName "ltid_seq")
  let seq_space = zip seq_is seq_dims

  let segspace = SegSpace ltid_flat $ seq_space ++ [(ltid_y, dim_y), (ltid_x, dim_x)]
      lvl' =
        SegThread
          (segNumGroups lvl)
          (segGroupSize lvl)
          (SegNoVirtFull (SegSeqDims [0 .. length seq_dims - 1]))

  ((t_v, res_v, res_i), stms) <- runBuilder $ do
    (res_v, res_i) <-
      localScope (scopeOfSegSpace segspace) $
        f seq_is (ltid_y, ltid_x)
    t_v <- subExpType res_v
    pure (t_v, res_v, res_i)

  let ret = WriteReturns mempty (Shape [arr_size]) updt_arr [(Slice [DimFix res_i], res_v)]
  let body = KernelBody () stms [ret]

  letExp desc <=< renameExp $ Op $ SegOp $ SegMap lvl' segspace [t_v] body

-- | Maps each variable bound by some 'Stm' to the set of names it
-- (transitively) depends on.  In particular, this set contains the kernel
-- thread indices the variable varies with.  A variable that is absent from
-- the table is bound outside the kernel, so it can be treated as invariant
-- to all dimensions.
type VarianceTable = M.Map VName Names

-- | Recognise a redomap 'Screma' that is suitable for tiling.
--
-- The reductions are combined with 'singleReduce'.  The statement qualifies
-- when all of the following hold:
--
--   * the parameters of the reduction lambda have scalar row types;
--   * the parameters of the map lambda are scalars;
--   * the map lambda returns exactly the reduction's types, i.e. there are
--     no map-out arrays, and these types are scalars;
--   * the input array list is non-empty.
--
-- On success, returns the width, the input arrays, and the reduction's
-- commutativity, operator lambda and neutral elements, together with the
-- map lambda.
isTileableRedomap ::
  Stm GPU ->
  Maybe
    ( SubExp,
      [VName],
      (Commutativity, Lambda GPU, [SubExp], Lambda GPU)
    )
isTileableRedomap stm = do
  Op (OtherOp (Screma w arrs form)) <- Just $ stmExp stm
  (reds, map_lam) <- isRedomapSOAC form
  Reduce red_comm red_lam red_nes <- Just $ singleReduce reds
  let map_ps = lambdaParams map_lam
      map_ret = lambdaReturnType map_lam
      scalarRow p = primType $ rowType $ paramType p
      -- Checked left to right, in the same order as the original guards.
      tileable =
        and
          [ all scalarRow (lambdaParams red_lam),
            all scalarRow map_ps,
            map_ret == lambdaReturnType red_lam, -- No mapout arrays.
            not (null arrs),
            all primType map_ret,
            all (primType . paramType) map_ps
          ]
  if tileable
    then Just (w, arrs, (red_comm, red_lam, red_nes, map_lam))
    else Nothing

-- | Default variance rule for a statement.
--
-- Every name bound by the statement's pattern is mapped to the union, over
-- all names free in the statement, of that free name together with its own
-- entry in the table (if it has one).
defVarianceInStm :: VarianceTable -> Stm GPU -> VarianceTable
defVarianceInStm variance stm =
  foldl' (\tbl name -> M.insert name deps tbl) variance bound_names
  where
    bound_names = patNames $ stmPat stm
    deps = mconcat [withOwnDeps v | v <- namesToList (freeIn stm)]
    withOwnDeps v = oneName v <> M.findWithDefault mempty v variance

-- | Extend a variance table with the names bound by one statement.
--
-- A 'Screma' recognised by 'isTileableRedomap' gets special treatment.
-- Its pattern is first handled by 'defVarianceInStm'.  Then, for each input
-- array, variance is propagated to the corresponding map-lambda parameter,
-- and from there to the matching element and accumulator parameters of the
-- reduction lambda.  Finally, the statements of both lambda bodies are
-- processed with 'varianceInStms'.  Every other statement is handled by
-- 'defVarianceInStm'.  (The special case was originally extended when
-- dealing with streams.)
varianceInStm :: VarianceTable -> Stm GPU -> VarianceTable
varianceInStm v0 stm@(Let _ _ (Op (OtherOp Screma {})))
  | Just (_, arrs, (_, red_lam, red_nes, map_lam)) <- isTileableRedomap stm =
      let v = defVarianceInStm v0 stm
          red_ps = lambdaParams red_lam
          map_ps = lambdaParams map_lam
          card_red = length red_nes
          -- A reduction lambda over 'card_red' neutral elements takes
          -- 'card_red' accumulator parameters followed by 'card_red'
          -- element parameters, so the split point is 'card_red' itself.
          -- Splitting at (card_red `quot` 2) would give no accumulators
          -- for a single reduction, so zip4 would pair nothing.
          (acc_lam_f, arr_lam_f) = splitAt card_red red_ps
          stm_lam = bodyStms (lambdaBody map_lam) <> bodyStms (lambdaBody red_lam)

          -- v_a: input array, v_fm: map parameter, v_fr_acc / v_fr_var:
          -- reduction accumulator / element parameter.
          f vacc (v_a, v_fm, v_fr_acc, v_fr_var) =
            let vrc = oneName v_a <> M.findWithDefault mempty v_a vacc
                vacc' = M.insert v_fm vrc vacc
                vrc' = oneName v_fm <> vrc
             in M.insert v_fr_acc (oneName v_fr_var <> vrc') $ M.insert v_fr_var vrc' vacc'

          v' =
            foldl' f v $
              zip4
                arrs
                (map paramName map_ps)
                (map paramName acc_lam_f)
                (map paramName arr_lam_f)
       in varianceInStms v' stm_lam
varianceInStm v0 stm = defVarianceInStm v0 stm

-- | Extend a variance table with every statement of a sequence, processing
-- the statements left to right with 'varianceInStm'.
varianceInStms :: VarianceTable -> Stms GPU -> VarianceTable
varianceInStms table stms = foldl' varianceInStm table stms

----------------
---- Helpers for building the environment that binds array variable names to their index functions
----------------

-- | Index functions whose expressions are 64-bit integer expressions over
-- program variables.
type IxFun = IxFun.IxFun (TPrimExp Int64 VName)

-- | Map from array variable names to their corresponding index functions.
-- The information is not guaranteed to be exact.  For example, we assume
-- that ifs and loops return arrays laid out in normalized (row-major) form
-- in memory.  Only aliasing statements, such as transpositions and slices,
-- are recorded.
type IxFnEnv = M.Map VName IxFun

-- | Map from the leading parameters of a 'WithAcc' lambda (one per
-- accumulator input) to that accumulator's operator lambda and neutral
-- elements.
type WithEnv = M.Map VName (Lambda GPU, [SubExp])

-- | The pair of environments updated by 'changeEnv'.
type Env = (WithEnv, IxFnEnv)

-- | Update both components of the environment for the binding @y = e@.
-- The accumulator environment is updated by 'changeWithEnv', and the
-- index-function environment by 'changeIxFnEnv'; the two run in that order.
changeEnv :: Env -> VName -> Exp GPU -> TileM Env
changeEnv (with_env, ixfn_env) y e =
  (,)
    <$> changeWithEnv with_env e
    <*> changeIxFnEnv ixfn_env y e

-- | Record the accumulators introduced by a 'WithAcc' expression.
--
-- The leading parameters of the 'WithAcc' lambda correspond to the
-- accumulator inputs, in order.  Each is mapped to its input's operator
-- lambda and neutral elements.  The operator's leading parameters are
-- dropped first, one per dimension of the input's shape (presumably the
-- index parameters).  Entries already in the environment take precedence
-- ('M.union' is left-biased).  Any other expression leaves the environment
-- unchanged.
--
-- Inputs without an operator ('Nothing', which the IR permits) are
-- skipped.  Previously they were stored as an 'error' thunk that crashed
-- whenever the entry was looked up and forced.
changeWithEnv :: WithEnv -> Exp GPU -> TileM WithEnv
changeWithEnv with_env (WithAcc accum_decs inner_lam) =
  pure $ M.union with_env $ M.fromList bindings
  where
    acc_params = map paramName $ lambdaParams inner_lam
    -- Pair parameters with inputs *before* filtering, so that skipping an
    -- operator-less input cannot shift the parameter/operator alignment.
    bindings =
      [ (p, (dropIndexParams shp lam, ne))
        | (p, (shp, _, Just (lam, ne))) <- zip acc_params accum_decs
      ]
    dropIndexParams shp lam =
      lam {lambdaParams = drop (length $ shapeDims shp) $ lambdaParams lam}
changeWithEnv with_env _ = pure with_env

-- | @composeIxfuns env y x ixf_fun@ records in @env@ that @y@ has the index
-- function obtained by applying @ixf_fun@ to the index function of @x@.
-- If @x@ has no entry yet and is an array, its index function is taken to
-- be an 'IxFun.iota' (row-major) over its shape.  If @x@ is not an array,
-- @env@ is returned unchanged.
composeIxfuns :: IxFnEnv -> VName -> VName -> (IxFun -> IxFun) -> TileM IxFnEnv
composeIxfuns env y x ixf_fun =
  maybe fromScope (pure . record) (M.lookup x env)
  where
    record ixf = M.insert y (ixf_fun ixf) env
    fromScope = do
      x_t <- lookupType x
      case x_t of
        Array _ shp _ -> pure $ record $ IxFun.iota $ map ExpMem.pe64 $ shapeDims shp
        _ -> pure env

-- | Update the index-function environment for the binding @y = e@.
--
-- The following operations derive @y@'s index function from that of their
-- source array @x@ (via 'composeIxfuns'): arbitrary reshapes, coercing
-- reshapes, rearranges, slicing ('Index') and 'Opaque' of a variable.
-- A 'Manifest' gives @y@ an 'IxFun.iota' over @x@'s shape permuted by the
-- manifest's permutation.  Any other expression leaves the environment
-- unchanged.
changeIxFnEnv :: IxFnEnv -> VName -> Exp GPU -> TileM IxFnEnv
changeIxFnEnv env y e =
  case e of
    BasicOp (Reshape ReshapeArbitrary new_shape x) ->
      composeIxfuns env y x (`IxFun.reshape` dimsOf new_shape)
    BasicOp (Reshape ReshapeCoerce new_shape x) ->
      composeIxfuns env y x (`IxFun.coerce` dimsOf new_shape)
    BasicOp (Manifest perm x) -> do
      x_t <- lookupType x
      case x_t of
        Array _ shp _ ->
          -- NOTE(review): this permutes an iota over x's own shape by
          -- 'perm'; confirm the resulting logical shape is the intended one.
          pure $ M.insert y (IxFun.permute (IxFun.iota (dimsOf shp)) perm) env
        _ -> error "In TileLoops/Shared.hs, changeIxFnEnv: manifest applied to a non-array!"
    BasicOp (Rearrange perm x) ->
      composeIxfuns env y x (`IxFun.permute` perm)
    BasicOp (Index x slc) ->
      composeIxfuns env y x (`IxFun.slice` Slice (map (fmap ExpMem.pe64) (unSlice slc)))
    BasicOp (Opaque _ (Var x)) ->
      composeIxfuns env y x id
    _ -> pure env
  where
    dimsOf = map ExpMem.pe64 . shapeDims