{-# LANGUAGE FlexibleContexts #-}

module Futhark.Optimise.TileLoops.Shared
  ( TileM,
    Env,
    index,
    update,
    forLoop',
    forLoop,
    segMap1D,
    segMap2D,
    segMap3D,
    segScatter2D,
    VarianceTable,
    varianceInStms,
    isTileableRedomap,
    changeEnv,
    TileKind (..),
  )
where

import Control.Monad.Reader
import Control.Monad.State
import Data.List (foldl', zip4)
import qualified Data.Map as M
import Futhark.IR.GPU
import qualified Futhark.IR.Mem.IxFun as IxFun
import qualified Futhark.IR.SeqMem as ExpMem
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Transform.Rename

-- | The monad the tiling pass runs in: read-only access to the GPU
-- type 'Scope', layered over a 'State' carrying the 'VNameSource'
-- used to generate fresh names.
type TileM = ReaderT (Scope GPU) (State VNameSource)

-- | Are we working with full or partial tiles?
--
-- NOTE(review): presumably 'TilePartial' denotes tiles that may overhang
-- the iteration space (and so need bounds checks), while 'TileFull'
-- denotes tiles known to lie entirely in bounds — confirm against the
-- users of this type.
data TileKind = TilePartial | TileFull

-- | Index an array with the indices given in @outer_indices@.  Any inner
-- dimensions of @arr@ that are not indexed by @outer_indices@ are sliced
-- in their entirety (offset 0, length equal to the dimension, stride 1).
-- The result is bound to a fresh variable whose name is based on
-- @se_desc@.
index :: MonadBuilder m => String -> VName -> [VName] -> m VName
index se_desc arr outer_indices = do
  arr_t <- lookupType arr
  let shape = arrayShape arr_t
      -- Dimensions left over after the outer indices have been applied.
      inner_dims = shapeDims $ stripDims (length outer_indices) shape
      untouched d = DimSlice (intConst Int64 0) d (intConst Int64 1)
      inner_slices = map untouched inner_dims
      slice = Slice $ map (DimFix . Var) outer_indices ++ inner_slices
  letExp se_desc $ BasicOp $ Index arr slice

-- | Write @new_elem@ into @arr@ at the position obtained by fixing the
-- leading dimensions to @indices@, using an in-place 'Update' with
-- 'Unsafe' safety (i.e. not dynamically checked).  The updated array is
-- bound to a fresh variable whose name is based on @se_desc@.
update :: MonadBuilder m => String -> VName -> [VName] -> SubExp -> m VName
update se_desc arr indices new_elem =
  letExp se_desc . BasicOp $
    Update Unsafe arr (Slice $ map (DimFix . Var) indices) new_elem

-- | Build a sequential @for@ loop with an 'Int64' counter running up to
-- the given bound, threading the given variables through the loop as
-- loop-carried (merge) values.  Returns the names bound to all loop
-- results, in the same order as the merge variables.
forLoop' ::
  -- | Loop bound (number of iterations).
  SubExp ->
  -- | Initial values of the loop-carried variables.
  [VName] ->
  -- | Loop body: receives the loop counter and the loop-carried
  -- parameters, and builds the body that produces their next values.
  (VName -> [VName] -> Builder GPU (Body GPU)) ->
  Builder GPU [VName]
forLoop' i_bound merge body = do
  i <- newVName "i" -- could give this as arg to the function
  let loop_form = ForLoop i Int64 i_bound []

  merge_ts <- mapM lookupType merge
  -- The merge parameters are declared 'Unique'; presumably so the body
  -- may update loop-carried arrays in place.
  loop_inits <- mapM (\merge_t -> newParam "merge" $ toDecl merge_t Unique) merge_ts

  loop_body <-
    runBodyBuilder . inScopeOf loop_form . localScope (scopeOfFParams loop_inits) $
      body i $ map paramName loop_inits

  letTupExp "loop" $
    DoLoop (zip loop_inits $ map Var merge) loop_form loop_body

-- | Like 'forLoop'', but return only the name bound to the first loop
-- result.  Calls 'error' with a descriptive message if the loop has no
-- results (i.e. the merge list is empty).
forLoop ::
  -- | Loop bound (number of iterations).
  SubExp ->
  -- | Initial values of the loop-carried variables.
  [VName] ->
  -- | Loop body: loop counter and loop-carried parameters to body.
  (VName -> [VName] -> Builder GPU (Body GPU)) ->
  Builder GPU VName
forLoop i_bound merge body = do
  res_list <- forLoop' i_bound merge body
  case res_list of
    r : _ -> pure r
    [] -> error "forLoop: loop produced no results (empty merge list)"

-- | Build a one-dimensional 'SegMap' at level @lvl@ whose single
-- dimension has as many iterations as the group size of @lvl@.  The
-- body is produced by @f@ applied to the thread index; each result is
-- returned with the given 'ResultManifest'.  The kernel body is renamed
-- before use.  Returns the names bound to the 'SegMap' results.
segMap1D ::
  -- | Description used to name the results.
  String ->
  -- | Level of the 'SegMap'.
  SegLevel ->
  -- | Manifest used for every kernel result.
  ResultManifest ->
  -- | Kernel body, given the thread index.
  (VName -> Builder GPU Result) ->
  Builder GPU [VName]
segMap1D desc lvl manifest f = do
  ltid <- newVName "ltid"
  ltid_flat <- newVName "ltid_flat"
  let space = SegSpace ltid_flat [(ltid, unCount $ segGroupSize lvl)]

  ((ts, res), stms) <- localScope (scopeOfSegSpace space) . runBuilder $ do
    body_res <- f ltid
    body_ts <- mapM subExpResType body_res
    pure (body_ts, body_res)
  Body _ stms' res' <- renameBody $ mkBody stms res

  let ret (SubExpRes cs se) = Returns manifest cs se
  letTupExp desc $
    Op . SegOp $
      SegMap lvl space ts $ KernelBody () stms' $ map ret res'

-- | Build a two-dimensional 'SegMap' at level @lvl@ over the space
-- @[dim_y, dim_x]@, with @y@ outermost.  The body is produced by @f@
-- applied to the thread indices @(ltid_y, ltid_x)@; each result is
-- returned with the given 'ResultManifest'.  The whole expression is
-- renamed before being bound.  Returns the names bound to the 'SegMap'
-- results.
segMap2D ::
  -- | Description used to name the results.
  String ->
  -- | Level of the 'SegMap'.
  SegLevel ->
  -- | Manifest used for every kernel result.
  ResultManifest ->
  -- | @(dim_y, dim_x)@: outer then inner dimension.
  (SubExp, SubExp) ->
  -- | Kernel body, given @(ltid_y, ltid_x)@.
  ((VName, VName) -> Builder GPU Result) ->
  Builder GPU [VName]
segMap2D desc lvl manifest (dim_y, dim_x) f = do
  ltid_xx <- newVName "ltid_x"
  ltid_yy <- newVName "ltid_y"
  ltid_flat <- newVName "ltid_flat"
  let segspace = SegSpace ltid_flat [(ltid_yy, dim_y), (ltid_xx, dim_x)]

  ((ts, res), stms) <- localScope (scopeOfSegSpace segspace) . runBuilder $ do
    body_res <- f (ltid_yy, ltid_xx)
    body_ts <- mapM subExpResType body_res
    pure (body_ts, body_res)

  let ret (SubExpRes cs se) = Returns manifest cs se
  letTupExp desc <=< renameExp $
    Op . SegOp $
      SegMap lvl segspace ts $ KernelBody () stms $ map ret res

-- | Build a three-dimensional 'SegMap' at level @lvl@ over the space
-- @[dim_z, dim_y, dim_x]@, with @z@ outermost.  The body is produced by
-- @f@ applied to the thread indices @(ltid_z, ltid_y, ltid_x)@; each
-- result is returned with the given 'ResultManifest'.  The whole
-- expression is renamed before being bound.  Returns the names bound to
-- the 'SegMap' results.
segMap3D ::
  -- | Description used to name the results.
  String ->
  -- | Level of the 'SegMap'.
  SegLevel ->
  -- | Manifest used for every kernel result.
  ResultManifest ->
  -- | @(dim_z, dim_y, dim_x)@: outermost to innermost dimension.
  (SubExp, SubExp, SubExp) ->
  -- | Kernel body, given @(ltid_z, ltid_y, ltid_x)@.
  ((VName, VName, VName) -> Builder GPU Result) ->
  Builder GPU [VName]
segMap3D desc lvl manifest (dim_z, dim_y, dim_x) f = do
  ltid_flat <- newVName "ltid_flat"
  ltid_z <- newVName "ltid_z"
  ltid_y <- newVName "ltid_y"
  ltid_x <- newVName "ltid_x"
  let segspace =
        SegSpace ltid_flat [(ltid_z, dim_z), (ltid_y, dim_y), (ltid_x, dim_x)]

  ((ts, res), stms) <- localScope (scopeOfSegSpace segspace) . runBuilder $ do
    body_res <- f (ltid_z, ltid_y, ltid_x)
    body_ts <- mapM subExpResType body_res
    pure (body_ts, body_res)

  let ret (SubExpRes cs se) = Returns manifest cs se
  letTupExp desc <=< renameExp $
    Op . SegOp $
      SegMap lvl segspace ts $ KernelBody () stms $ map ret res

-- | Build a scatter-style 'SegMap' over the space
-- @seq_dims ++ [dim_y, dim_x]@.  For every point, @f@ receives the
-- sequential indices and @(ltid_y, ltid_x)@ and returns a pair
-- @(value, index)@.  The kernel writes @value@ at @index@ of @updt_arr@,
-- which is declared with shape @[arr_size]@, via 'WriteReturns'.
--
-- The level is always 'SegThread', reusing only the group count and
-- group size of @lvl@.  The leading @seq_dims@ are listed in
-- 'SegSeqDims' under 'SegNoVirtFull'; presumably this makes them run
-- sequentially within each thread.  Returns the name bound to the
-- updated array.
segScatter2D ::
  -- | Description used to name the result.
  String ->
  -- | Size of the (one-dimensional) destination array.
  SubExp ->
  -- | Destination array being updated.
  VName ->
  -- | Level whose group count and group size are reused.
  SegLevel ->
  -- | Dimensions of the sequential loops on top.
  [SubExp] ->
  -- | @(dim_x, dim_y)@: the first component becomes the innermost (x)
  -- dimension.  NOTE(review): this is the reverse of 'segMap2D';
  -- confirm that callers pass the components in this order.
  (SubExp, SubExp) ->
  -- | Kernel body: sequential indices and @(ltid_y, ltid_x)@ to
  -- @(value, index)@.
  ([VName] -> (VName, VName) -> Builder GPU (SubExp, SubExp)) ->
  Builder GPU VName
segScatter2D desc arr_size updt_arr lvl seq_dims (dim_x, dim_y) f = do
  ltid_flat <- newVName "ltid_flat"
  ltid_y <- newVName "ltid_y"
  ltid_x <- newVName "ltid_x"

  seq_is <- replicateM (length seq_dims) (newVName "ltid_seq")
  let seq_space = zip seq_is seq_dims
      segspace =
        SegSpace ltid_flat $ seq_space ++ [(ltid_y, dim_y), (ltid_x, dim_x)]
      lvl' =
        SegThread
          (segNumGroups lvl)
          (segGroupSize lvl)
          (SegNoVirtFull (SegSeqDims [0 .. length seq_dims - 1]))

  -- Both the body and the type query run inside the SegSpace scope (as
  -- in segMap1D/2D/3D), so a value that mentions a thread index can
  -- still be typed.
  ((t_v, res_v, res_i), stms) <-
    runBuilder . localScope (scopeOfSegSpace segspace) $ do
      (val, idx) <- f seq_is (ltid_y, ltid_x)
      val_t <- subExpType val
      pure (val_t, val, idx)

  let ret = WriteReturns mempty (Shape [arr_size]) updt_arr [(Slice [DimFix res_i], res_v)]
      body = KernelBody () stms [ret]

  letExp desc <=< renameExp $ Op $ SegOp $ SegMap lvl' segspace [t_v] body

-- | The variance table keeps a mapping from a variable name
-- (something produced by a 'Stm') to the kernel thread indices
-- that name depends on.  If a variable is not present in this table,
-- that means it is bound outside the kernel (and so can be considered
-- invariant to all dimensions).
--
-- This is a plain type synonym for a 'M.Map' from each variable to the
-- set ('Names') of thread-index names it varies with.
type VarianceTable = M.Map VName Names

-- | Recognise a redomap that the tiling transformation can handle: a
-- 'Screma' over at least one array, where
--
-- * the parameters of both the map lambda and the (single, as produced
--   by 'singleReduce') reduction lambda have primitive row types,
-- * the map lambda returns exactly what the reduction consumes (no
--   map-out arrays), all of it primitive, and
-- * every map lambda parameter is itself primitive.
--
-- On success, returns the width, the input arrays, and the reduction
-- commutativity, reduction lambda, neutral elements and map lambda.
isTileableRedomap ::
  Stm GPU ->
  Maybe
    ( SubExp,
      [VName],
      (Commutativity, Lambda GPU, [SubExp], Lambda GPU)
    )
isTileableRedomap stm =
  case stmExp stm of
    Op (OtherOp (Screma w arrs form))
      | Just (reds, map_lam) <- isRedomapSOAC form,
        Reduce red_comm red_lam red_nes <- singleReduce reds,
        tileable arrs red_lam map_lam ->
          Just (w, arrs, (red_comm, red_lam, red_nes, map_lam))
    _ -> Nothing
  where
    scalarRow :: Param Type -> Bool
    scalarRow = primType . rowType . paramType

    -- Conditions are checked left to right, in the original order.
    tileable :: [VName] -> Lambda GPU -> Lambda GPU -> Bool
    tileable arrs red_lam map_lam =
      and
        [ all scalarRow (lambdaParams red_lam),
          all scalarRow (lambdaParams map_lam),
          -- No mapout arrays.
          lambdaReturnType map_lam == lambdaReturnType red_lam,
          not (null arrs),
          all primType (lambdaReturnType map_lam),
          all (primType . paramType) (lambdaParams map_lam)
        ]

-- | The default variance rule.  Every name bound by the statement's
-- pattern is mapped to the union, over all free variables of the
-- statement, of the variable itself and its entry in the table (if any).
-- Entries for the bound names replace any existing ones.
defVarianceInStm :: VarianceTable -> Stm GPU -> VarianceTable
defVarianceInStm variance stm =
  M.fromList [(name, deps) | name <- patNames (stmPat stm)] `M.union` variance
  where
    deps :: Names
    deps = foldMap withDeps $ namesToList $ freeIn stm

    withDeps :: VName -> Names
    withDeps v = oneName v <> M.findWithDefault mempty v variance

-- | Extend a variance table with the names bound by one statement.
--
-- Most statements use 'defVarianceInStm'.  Tileable redomaps (see
-- 'isTileableRedomap') are additionally analysed inside their lambdas:
--
-- * each map parameter inherits the variance of the array it is zipped
--   with (plus that array's name),
-- * the matching reduction element parameter inherits the variance of the
--   map parameter (plus that parameter's name),
-- * the matching reduction accumulator parameter inherits the variance of
--   the element parameter (plus that parameter's name),
--
-- and then the statements of both lambda bodies are processed in order.
--
-- Just in case you need the Screma being treated differently than by
-- default; previously Cosmin had to enhance it when dealing with stream.
varianceInStm :: VarianceTable -> Stm GPU -> VarianceTable
varianceInStm v0 stm@(Let _ _ (Op (OtherOp Screma {})))
  | Just (_, arrs, (_, red_lam, red_nes, map_lam)) <- isTileableRedomap stm =
      let v = defVarianceInStm v0 stm
          map_ps = lambdaParams map_lam
          -- The reduction operator takes one accumulator parameter per
          -- neutral element, followed by the same number of element
          -- parameters, so split exactly at the number of neutral
          -- elements.  (Splitting at half of that leaves no accumulator
          -- parameters for a single reduction, and mis-pairs parameters
          -- when there are several.)
          (acc_lam_f, arr_lam_f) = splitAt (length red_nes) (lambdaParams red_lam)
          stm_lam = bodyStms (lambdaBody map_lam) <> bodyStms (lambdaBody red_lam)

          -- Record variance for one (array, map parameter, reduction
          -- accumulator parameter, reduction element parameter) tuple.
          f vacc (v_a, v_fm, v_fr_acc, v_fr_var) =
            let vrc = oneName v_a <> M.findWithDefault mempty v_a vacc
                vacc' = M.insert v_fm vrc vacc
                vrc' = oneName v_fm <> vrc
             in M.insert v_fr_acc (oneName v_fr_var <> vrc') $
                  M.insert v_fr_var vrc' vacc'

          v' =
            foldl' f v $
              zip4
                arrs
                (map paramName map_ps)
                (map paramName acc_lam_f)
                (map paramName arr_lam_f)
       in varianceInStms v' stm_lam
varianceInStm v0 stm = defVarianceInStm v0 stm

-- | Thread a variance table through a sequence of statements, in order,
-- updating it with 'varianceInStm' at each statement.
varianceInStms :: VarianceTable -> Stms GPU -> VarianceTable
varianceInStms table stms = foldl' varianceInStm table stms

----------------
---- Helpers for building the environment that binds array variable names to their index functions
----------------

-- | Index functions whose components are 64-bit integer expressions over
--   variable names.
type IxFun = IxFun.IxFun (TPrimExp Int64 VName)

-- | Map from array variable names to their corresponding index functions.
--   The info is not guaranteed to be exact, e.g., we assume ifs and loops
--   return arrays laid out in normalized (row-major) form in memory.
--   We record statements that produce views of other arrays (reshape,
--   transposition, rotation, slicing, opaque wrappers), plus 'Manifest'.
type IxFnEnv = M.Map VName IxFun

-- | Accumulator information gathered from 'WithAcc' expressions: maps the
--   name of a leading parameter of the 'WithAcc' lambda to the
--   accumulator's combining operator and neutral elements.
type WithEnv = M.Map VName (Lambda GPU, [SubExp])

-- | The combined environment updated by 'changeEnv'.
type Env = (WithEnv, IxFnEnv)

-- | Update both halves of the environment for a statement that binds @y@
-- to the expression @e@.  The accumulator half is updated by
-- 'changeWithEnv' and the index-function half by 'changeIxFnEnv'.
changeEnv :: Env -> VName -> Exp GPU -> TileM Env
changeEnv (with_env, ixfn_env) y e =
  (,) <$> changeWithEnv with_env e <*> changeIxFnEnv ixfn_env y e

-- | For a 'WithAcc' expression, record each accumulator input's combining
-- operator and neutral elements.
--
-- * Entries are keyed by the name of the corresponding leading parameter
--   of the 'WithAcc' lambda.
-- * The operator has its first @r@ parameters removed, where @r@ is the
--   rank of the accumulator's shape (presumably its index parameters).
-- * Names already present in the environment keep their existing entry.
-- * An accumulator without an operator leads to an 'error'.
--
-- Every other expression leaves the environment unchanged.
changeWithEnv :: WithEnv -> Exp GPU -> TileM WithEnv
changeWithEnv with_env expr =
  pure $ case expr of
    WithAcc acc_inputs acc_lam ->
      -- 'zip' stops at the shorter list, so only as many leading lambda
      -- parameters as there are accumulator inputs are used.
      let names = map paramName $ lambdaParams acc_lam
       in with_env `M.union` M.fromList (zip names (map accOp acc_inputs))
    _ -> with_env
  where
    accOp (_, _, Nothing) =
      error "What the hack is an accumulator without operator?"
    accOp (shp, _, Just (op_lam, nes)) =
      let op_params = drop (length $ shapeDims shp) $ lambdaParams op_lam
       in (op_lam {lambdaParams = op_params}, nes)

-- | Bind @y@ to @ixf_fun@ applied to the index function of @x@.
--
-- If @x@ has no recorded index function, it is assumed to be laid out in
-- row-major order ('IxFun.iota' over its shape).  If @x@ has no entry and
-- is not an array, the environment is returned unchanged.
composeIxfuns :: IxFnEnv -> VName -> VName -> (IxFun -> IxFun) -> TileM IxFnEnv
composeIxfuns env y x ixf_fun = do
  x_ixfn <- maybe rowMajor (pure . Just) (M.lookup x env)
  pure $ maybe env (\ixf -> M.insert y (ixf_fun ixf) env) x_ixfn
  where
    -- Only consulted when @x@ is missing from the environment.
    rowMajor :: TileM (Maybe IxFun)
    rowMajor = do
      x_t <- lookupType x
      pure $ case x_t of
        Array _ shp _ -> Just $ IxFun.iota $ map ExpMem.pe64 $ shapeDims shp
        _ -> Nothing

-- | Record the index function of @y@ when it is bound by a basic operation
-- over another array @x@.
--
-- * 'Reshape', 'Rearrange', 'Rotate', 'Index' and 'Opaque' compose the
--   corresponding transformation onto the index function of @x@ (see
--   'composeIxfuns').
-- * 'Manifest' gets a fresh permuted 'IxFun.iota' over the shape of @x@;
--   it is an 'error' if @x@ is not an array.
--
-- All other expressions leave the environment unchanged.
changeIxFnEnv :: IxFnEnv -> VName -> Exp GPU -> TileM IxFnEnv
changeIxFnEnv env y (BasicOp op) =
  case op of
    Reshape shp_chg x ->
      viewOf x (`IxFun.reshape` map (fmap ExpMem.pe64) shp_chg)
    Manifest perm x -> do
      x_t <- lookupType x
      case x_t of
        Array _ shp _ -> do
          let dims = map ExpMem.pe64 $ shapeDims shp
          -- NOTE(review): the permutation is applied to an iota over the
          -- shape of @x@; confirm this matches how 'Manifest' lays out
          -- its result in memory.
          pure $ M.insert y (IxFun.permute (IxFun.iota dims) perm) env
        _ ->
          error "In TileLoops/Shared.hs, changeIxFnEnv: manifest applied to a non-array!"
    Rearrange perm x ->
      viewOf x (`IxFun.permute` perm)
    Rotate rs x ->
      viewOf x (`IxFun.rotate` map ExpMem.pe64 rs)
    Index x slc ->
      viewOf x (`IxFun.slice` Slice (map (fmap ExpMem.pe64) (unSlice slc)))
    Opaque _ (Var x) ->
      viewOf x id
    _ -> pure env
  where
    viewOf x = composeIxfuns env y x
changeIxFnEnv env _ _ = pure env