module Futhark.Optimise.TileLoops.Shared
  ( TileM,
    segMap1D,
    segMap2D,
    segMap3D,
    segScatter2D,
    VarianceTable,
    varianceInStms,
    isTileableRedomap,
  )
where

import Control.Monad.Reader
import Control.Monad.State
import Data.List (foldl', zip4)
import qualified Data.Map as M
import Futhark.IR.GPU
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Transform.Rename

-- | Monad for tiling code: a reader over the GPU 'Scope', on top of a
-- 'State' monad that carries a 'VNameSource' for generating fresh names.
type TileM = ReaderT (Scope GPU) (State VNameSource)

-- | Build a one-dimensional 'SegMap' at level @lvl@ and bind its results
-- to fresh names based on @desc@.
--
-- The single dimension's extent is the level's group size
-- ('segGroupSize').  The body builder @f@ is handed the thread index of
-- that dimension, and is run with the segment space in scope.  Each value
-- it returns becomes a 'Returns' kernel result with the given
-- 'ResultManifest'.  The generated body is renamed with 'renameBody'
-- before being placed in the kernel.
segMap1D ::
  String ->
  SegLevel ->
  ResultManifest ->
  (VName -> Builder GPU Result) ->
  Builder GPU [VName]
segMap1D desc lvl manifest f = do
  tid <- newVName "ltid"
  tid_flat <- newVName "ltid_flat"
  let grpsize = unCount (segGroupSize lvl)
      space = SegSpace tid_flat [(tid, grpsize)]
      buildBody = do
        body_res <- f tid
        body_ts <- traverse subExpResType body_res
        pure (body_ts, body_res)
  ((ts, res), stms) <- localScope (scopeOfSegSpace space) (runBuilder buildBody)
  Body _ renamed_stms renamed_res <- renameBody (mkBody stms res)
  let toKernelResult (SubExpRes cs se) = Returns manifest cs se
      kbody = KernelBody () renamed_stms (map toKernelResult renamed_res)
  letTupExp desc (Op (SegOp (SegMap lvl space ts kbody)))

-- | Build a two-dimensional 'SegMap' at level @lvl@ and bind its results
-- to fresh names based on @desc@.
--
-- The dimensions are given outermost first, as @(dim_y, dim_x)@.  The
-- segment space is @[(ltid_y, dim_y), (ltid_x, dim_x)]@, and @f@ receives
-- the thread indices in the same order, @(ltid_y, ltid_x)@.  Each value
-- returned by @f@ becomes a 'Returns' kernel result with the given
-- 'ResultManifest'.  The complete 'SegOp' expression is renamed with
-- 'renameExp' before it is bound.
segMap2D ::
  String -> -- desc
  SegLevel -> -- lvl
  ResultManifest -> -- manifest
  (SubExp, SubExp) -> -- (dim_y, dim_x), outermost first
  ( (VName, VName) -> -- f, receives (ltid_y, ltid_x)
    Builder GPU Result
  ) ->
  Builder GPU [VName]
segMap2D desc lvl manifest (dim_y, dim_x) f = do
  ltid_xx <- newVName "ltid_x"
  ltid_flat <- newVName "ltid_flat"
  ltid_yy <- newVName "ltid_y"
  let segspace = SegSpace ltid_flat [(ltid_yy, dim_y), (ltid_xx, dim_x)]

  -- Build the body with the thread indices in scope.
  ((ts, res), stms) <- localScope (scopeOfSegSpace segspace) . runBuilder $ do
    body_res <- f (ltid_yy, ltid_xx)
    body_ts <- mapM subExpResType body_res
    return (body_ts, body_res)

  let ret (SubExpRes cs se) = Returns manifest cs se
  -- The whole SegOp is renamed before it is bound.
  letTupExp desc <=< renameExp $
    Op . SegOp $
      SegMap lvl segspace ts $
        KernelBody () stms $
          map ret res

-- | Build a three-dimensional 'SegMap' at level @lvl@ and bind its results
-- to fresh names based on @desc@.
--
-- The dimensions are given outermost first, as @(dim_z, dim_y, dim_x)@.
-- The body builder @f@ receives the matching thread indices
-- @(ltid_z, ltid_y, ltid_x)@ and is run with the segment space in scope.
-- Each value it returns becomes a 'Returns' kernel result with the given
-- 'ResultManifest'.  The complete 'SegOp' is renamed with 'renameExp'
-- before it is bound.
segMap3D ::
  String -> -- desc
  SegLevel -> -- lvl
  ResultManifest -> -- manifest
  (SubExp, SubExp, SubExp) -> -- (dim_z, dim_y, dim_x)
  ( (VName, VName, VName) -> -- f
    Builder GPU Result
  ) ->
  Builder GPU [VName]
segMap3D desc lvl manifest (dim_z, dim_y, dim_x) f = do
  i_x <- newVName "ltid_x"
  i_flat <- newVName "ltid_flat"
  i_y <- newVName "ltid_y"
  i_z <- newVName "ltid_z"
  let space = SegSpace i_flat [(i_z, dim_z), (i_y, dim_y), (i_x, dim_x)]

  ((ts, res), stms) <-
    localScope (scopeOfSegSpace space) $
      runBuilder $ do
        body_res <- f (i_z, i_y, i_x)
        body_ts <- mapM subExpResType body_res
        pure (body_ts, body_res)

  let kbody = KernelBody () stms [Returns manifest cs se | SubExpRes cs se <- res]
  segop <- renameExp $ Op $ SegOp $ SegMap lvl space ts kbody
  letTupExp desc segop

-- | Build a two-dimensional 'SegMap' whose only result is a scatter
-- ('WriteReturns') into the array @updt_arr@, and bind it to fresh names
-- based on @desc@.
--
-- Each thread runs @f@ on its thread indices @(ltid_x, ltid_y)@.  @f@
-- returns a pair @(value, index)@, and the thread writes @value@ at
-- position @index@ of @updt_arr@.  The write treats @updt_arr@ as a
-- one-dimensional array of size @arr_size@.
--
-- The dimensions are given outermost first.  The first component is the
-- extent of @ltid_x@ (the outer dimension of the segment space), and the
-- second is the extent of @ltid_y@.
segScatter2D ::
  String -> -- desc
  SubExp -> -- arr_size
  VName -> -- updt_arr
  SegLevel -> -- lvl
  (SubExp, SubExp) -> -- (dim_x, dim_y), outermost first
  ((VName, VName) -> Builder GPU (SubExp, SubExp)) -> -- f, returns (value, index)
  Builder GPU [VName]
segScatter2D desc arr_size updt_arr lvl (dim_x, dim_y) f = do
  ltid_x <- newVName "ltid_x"
  ltid_y <- newVName "ltid_y"
  ltid_flat <- newVName "ltid_flat"
  let segspace = SegSpace ltid_flat [(ltid_x, dim_x), (ltid_y, dim_y)]

  -- As in the segMap* builders, run f with the thread indices in scope.
  -- This lets f, and the subExpType lookup on its result, resolve the
  -- types of ltid_x/ltid_y.
  ((t_v, res_v, res_i), stms) <- localScope (scopeOfSegSpace segspace) . runBuilder $ do
    (v, i) <- f (ltid_x, ltid_y)
    t <- subExpType v
    return (t, v, i)

  let ret = WriteReturns mempty (Shape [arr_size]) updt_arr [(Slice [DimFix res_i], res_v)]
      body = KernelBody () stms [ret]

  letTupExp desc <=< renameExp $ Op $ SegOp $ SegMap lvl segspace [t_v] body

-- | The variance table maps a variable name to the set of names it
-- (transitively) depends on.  This includes the kernel thread indices
-- the name varies with.  As computed by 'defVarianceInStm', the set also
-- contains the free variables themselves, so it is a superset of the
-- thread indices.
--
-- Entries are added for names bound by a 'Stm'.  For tileable redomaps,
-- 'varianceInStm' also adds entries for lambda parameters.
--
-- If a variable is not present in this table, that means it is bound
-- outside the kernel (and so can be considered invariant to all
-- dimensions).
type VarianceTable = M.Map VName Names

-- | Recognise a statement that is a redomap (a 'Screma' with reductions
-- and a map lambda) suitable for tiling.
--
-- On success, returns three things:
--
-- * the width;
-- * the input arrays;
-- * the @(commutativity, reduction lambda, neutral elements, map lambda)@
--   of the reduction obtained by combining all of the redomap's
--   reductions with 'singleReduce'.
--
-- The statement qualifies when:
--
-- * every reduction-lambda parameter is a scalar or a one-dimensional
--   array of scalars (its row type is primitive);
-- * the map lambda returns exactly the reduction's types, so there are no
--   map-out arrays;
-- * there is at least one input array;
-- * all map-lambda results and parameters are scalars.
isTileableRedomap ::
  Stm GPU ->
  Maybe
    ( SubExp,
      [VName],
      (Commutativity, Lambda GPU, [SubExp], Lambda GPU)
    )
isTileableRedomap stm
  | Op (OtherOp (Screma w arrs form)) <- stmExp stm,
    Just (reds, map_lam) <- isRedomapSOAC form,
    Reduce red_comm red_lam red_nes <- singleReduce reds,
    all (primType . rowType . paramType) $ lambdaParams red_lam,
    -- No mapout arrays.
    lambdaReturnType map_lam == lambdaReturnType red_lam,
    not (null arrs),
    all primType $ lambdaReturnType map_lam,
    -- Scalar map parameters; a scalar's row type is itself, so no
    -- separate row-type check on these parameters is needed.
    all (primType . paramType) $ lambdaParams map_lam =
      Just (w, arrs, (red_comm, red_lam, red_nes, map_lam))
  | otherwise =
      Nothing

-- | Default variance rule for a statement.
--
-- Every name bound by the statement's pattern is mapped to the same set:
-- the union, over all free variables of the statement, of the variable
-- itself plus whatever the table already records for it.
defVarianceInStm :: VarianceTable -> Stm GPU -> VarianceTable
defVarianceInStm variance stm = foldl' insertBound variance bound
  where
    bound = patNames (stmPat stm)
    deps = foldMap depsOf (namesToList (freeIn stm))
    depsOf name = oneName name <> M.findWithDefault mempty name variance
    insertBound table name = M.insert name deps table

-- | Extend a variance table with the names bound by one statement.
--
-- Most statements use the default rule, 'defVarianceInStm'.  A tileable
-- redomap (see 'isTileableRedomap') gets the default rule and is also
-- looked into:
--
-- * its map-lambda and reduction-lambda parameters are given variance
--   derived from the corresponding input arrays;
-- * the statements of both lambda bodies are then processed.
--
-- This special case exists in case a 'Screma' must be treated
-- differently from the default.  It was previously extended (by Cosmin)
-- to deal with streams.
varianceInStm :: VarianceTable -> Stm GPU -> VarianceTable
varianceInStm v0 stm@(Let _ _ (Op (OtherOp Screma {})))
  | Just (_, arrs, (_, red_lam, red_nes, map_lam)) <- isTileableRedomap stm =
      let v = defVarianceInStm v0 stm
          map_ps = lambdaParams map_lam
          -- The reduction lambda takes one accumulator parameter per
          -- neutral element, followed by the same number of element
          -- parameters.  (Splitting at half of length red_nes, as before,
          -- dropped every parameter for a single reduction.)
          (acc_lam_f, arr_lam_f) = splitAt (length red_nes) (lambdaParams red_lam)
          stm_lam = bodyStms (lambdaBody map_lam) <> bodyStms (lambdaBody red_lam)

          -- For the i'th input array v_a:
          -- * the i'th map parameter v_fm gets v_a plus v_a's variance;
          -- * the i'th reduction element parameter v_fr_var also gets v_fm;
          -- * the i'th accumulator v_fr_acc additionally gets v_fr_var.
          f vacc (v_a, v_fm, v_fr_acc, v_fr_var) =
            let vrc = oneName v_a <> M.findWithDefault mempty v_a vacc
                vacc' = M.insert v_fm vrc vacc
                vrc' = oneName v_fm <> vrc
             in M.insert v_fr_acc (oneName v_fr_var <> vrc') $
                  M.insert v_fr_var vrc' vacc'

          -- zip4 stops at the shortest list, so only as many entries are
          -- recorded as there are arrays, map parameters and reduction
          -- operands in common.
          v' =
            foldl' f v $
              zip4
                arrs
                (map paramName map_ps)
                (map paramName acc_lam_f)
                (map paramName arr_lam_f)
       in varianceInStms v' stm_lam
varianceInStm v0 stm = defVarianceInStm v0 stm

-- | Thread a variance table through a sequence of statements in order,
-- extending it with each statement via 'varianceInStm'.
varianceInStms :: VarianceTable -> Stms GPU -> VarianceTable
varianceInStms table stms = foldl' varianceInStm table stms