{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}

-- | Perform a restricted form of loop tiling within SegMaps.  We only
-- tile primitive types, to avoid excessive local memory use.
module Futhark.Optimise.TileLoops (tileLoops) where

import Control.Monad.Reader
import Control.Monad.State
import Data.Map.Strict qualified as M
import Data.Maybe (mapMaybe)
import Data.Sequence qualified as Seq
import Futhark.Analysis.Alias qualified as Alias
import Futhark.IR.GPU
import Futhark.IR.Prop.Aliases (consumedInStm)
import Futhark.MonadFreshNames
import Futhark.Optimise.BlkRegTiling
import Futhark.Optimise.TileLoops.Shared
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
import Prelude hiding (quot)

-- | The pass definition.
--
-- Applies 'optimiseStms' to the statements handed over by
-- 'intraproceduralTransformation', starting from an environment made
-- of two empty maps.  The tiling monad is a reader over the current
-- scope on top of a state monad carrying the name source; the final
-- name source is written back with 'modifyNameSource'.
tileLoops :: Pass GPU GPU
tileLoops =
  Pass "tile loops" "Tile stream loops inside kernels" $
    intraproceduralTransformation onStms
  where
    -- Run the tiling monad: the scope is the reader environment and
    -- the name source is the state.
    onStms scope stms =
      modifyNameSource $
        runState $
          runReaderT (optimiseStms (M.empty, M.empty) stms) scope

-- | Optimise the statements of a body with 'optimiseStms'.  The body
-- result is passed through unchanged.
optimiseBody :: Env -> Body GPU -> TileM (Body GPU)
optimiseBody env (Body () stms res) =
  Body () <$> optimiseStms env stms <*> pure res

-- | Optimise a sequence of statements from first to last.
--
-- The environment produced by 'optimiseStm' for one statement is the
-- environment used for the next one; the environment left over after
-- the last statement is discarded.  All of the statements are brought
-- into scope while the fold runs.
optimiseStms :: Env -> Stms GPU -> TileM (Stms GPU)
optimiseStms env stms =
  localScope (scopeOf stms) $ do
    (_, stms') <- foldM foldfun (env, mempty) $ stmsToList stms
    pure stms'
  where
    -- Accumulator: (current environment, statements emitted so far).
    foldfun :: (Env, Stms GPU) -> Stm GPU -> TileM (Env, Stms GPU)
    foldfun (e, ss) s = do
      (e', s') <- optimiseStm e s
      pure (e', ss <> s')

-- | Optimise a single statement, returning the (possibly updated)
-- environment and the statements that replace it.
--
-- For a 'SegMap' at 'SegThread' level, three strategies are tried in
-- order, and the first that succeeds wins:
--
-- 1. 'doRegTiling3D',
--
-- 2. 'mmBlkRegTiling',
--
-- 3. 'tileInKernelBody'.
--
-- In this case the environment is returned unchanged.
--
-- For any other statement the environment is updated with 'changeEnv'
-- and every body nested inside the expression is optimised
-- recursively with 'optimiseBody'.
optimiseStm :: Env -> Stm GPU -> TileM (Env, Stms GPU)
optimiseStm env stm@(Let pat aux (Op (SegOp (SegMap lvl@SegThread {} space ts kbody)))) = do
  res3dtiling <- localScope (scopeOfSegSpace space) $ doRegTiling3D stm
  stms' <-
    case res3dtiling of
      Just (extra_stms, stmt') -> pure (extra_stms <> oneStm stmt')
      Nothing -> do
        blkRegTiling_res <- mmBlkRegTiling env stm
        case blkRegTiling_res of
          Just (extra_stms, stmt') -> pure (extra_stms <> oneStm stmt')
          Nothing -> localScope (scopeOfSegSpace space) $ do
            (host_stms, (lvl', space', kbody')) <-
              tileInKernelBody mempty initial_variance lvl space ts kbody
            pure $
              host_stms
                <> oneStm (Let pat aux $ Op $ SegOp $ SegMap lvl' space' ts kbody')
  pure (env, stms')
  where
    -- Every name bound by the segment space starts out with an empty
    -- variance set ('mempty' here is the constant function).
    initial_variance = M.map mempty $ scopeOfSegSpace space
optimiseStm env (Let pat aux e) = do
  -- NOTE(review): 'head' is partial.  Presumably patterns reaching
  -- this point are non-empty, or 'changeEnv' does not force the name
  -- otherwise -- confirm against 'changeEnv'.
  env' <- changeEnv env (head $ patNames pat) e
  e' <- mapExpM (optimise env') e
  pure (env', oneStm $ Let pat aux e')
  where
    -- Only bodies are rewritten; everything else is left alone.
    optimise env' =
      identityMapper
        { mapOnBody = \scope -> localScope scope . optimiseBody env'
        }

-- | Try to tile the body of a kernel.
--
-- Tiling is only attempted when every kernel result is a 'Returns';
-- the results are then turned into ordinary 'SubExpRes' values and
-- the work is delegated to 'tileInBody'.  On success the returned
-- triple holds the level, space and kernel body of the tiled kernel,
-- together with statements that the caller places before the kernel.
-- If tiling is not attempted or does not succeed, the original level,
-- space and body are returned with no extra statements.
tileInKernelBody ::
  Names ->
  VarianceTable ->
  SegLevel ->
  SegSpace ->
  [Type] ->
  KernelBody GPU ->
  TileM (Stms GPU, (SegLevel, SegSpace, KernelBody GPU))
tileInKernelBody branch_variant initial_variance lvl initial_kspace ts kbody
  | Just kbody_res <- mapM isSimpleResult $ kernelBodyResult kbody = do
      maybe_tiled <-
        tileInBody branch_variant initial_variance lvl initial_kspace ts $
          Body () (kernelBodyStms kbody) kbody_res
      case maybe_tiled of
        Just (host_stms, tiling, tiledBody) -> do
          (res', stms') <-
            runBuilder $ mapM (tilingTileReturns tiling) =<< tiledBody mempty mempty
          pure
            ( host_stms,
              ( tilingLevel tiling,
                tilingSpace tiling,
                KernelBody () stms' res'
              )
            )
        Nothing ->
          pure (mempty, (lvl, initial_kspace, kbody))
  | otherwise =
      pure (mempty, (lvl, initial_kspace, kbody))
  where
    -- Only plain 'Returns' results are supported.
    isSimpleResult (Returns _ cs se) = Just $ SubExpRes cs se
    isSimpleResult _ = Nothing

-- | Look for something to tile in a body.
--
-- The statements are scanned from first to last.  For each statement
-- the following are tried in order:
--
-- * 2D tiling, when the statement is 'tileable', the space has at
--   least two dimensions, and every input array is accepted by
--   'invariantToOneOfTwoInnerDims';
--
-- * 1D tiling, when the statement is 'tileable' and the innermost
--   thread index is not in @branch_variant@;
--
-- * tiling inside the body of a for-loop whose merge parameters are
--   not free in the merge list itself, by a recursive call on the
--   loop body.
--
-- Both redomap cases additionally require at least one tiled input.
-- If none applies, the statement is appended to the prelude and the
-- scan continues with the next statement.  'Nothing' is returned when
-- the statements run out.
tileInBody ::
  Names ->
  VarianceTable ->
  SegLevel ->
  SegSpace ->
  [Type] ->
  Body GPU ->
  TileM (Maybe (Stms GPU, Tiling, TiledBody))
tileInBody branch_variant initial_variance initial_lvl initial_space res_ts (Body () initial_kstms stms_res) =
  descend mempty $ stmsToList initial_kstms
  where
    variance = varianceInStms initial_variance initial_kstms

    -- First argument: statements already passed over (the prelude).
    -- Second argument: statements still to be examined.
    descend _ [] =
      pure Nothing
    descend prestms (stm_to_tile : poststms)
      -- 2D tiling of redomap.
      | (gtids, kdims) <- unzip $ unSegSpace initial_space,
        Just (w, arrs, form) <- tileable stm_to_tile,
        Just inputs <-
          mapM (invariantToOneOfTwoInnerDims branch_variant variance gtids) arrs,
        not $ null $ tiledInputs inputs,
        gtid_y : gtid_x : top_gtids_rev <- reverse gtids,
        kdim_y : kdim_x : top_kdims_rev <- reverse kdims,
        (prestms', poststms') <-
          preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms),
        used <- freeIn stm_to_tile <> freeIn poststms' <> freeIn stms_res =
          Just . injectPrelude initial_space variance prestms' used
            <$> tileGeneric
              (tiling2d $ reverse $ zip top_gtids_rev top_kdims_rev)
              initial_lvl
              res_ts
              (stmPat stm_to_tile)
              (gtid_x, gtid_y)
              (kdim_x, kdim_y)
              w
              form
              inputs
              poststms'
              stms_res
      -- 1D tiling of redomap.
      | (gtid, kdim) : top_space_rev <- reverse $ unSegSpace initial_space,
        Just (w, arrs, form) <- tileable stm_to_tile,
        inputs <- map (is1DTileable gtid variance) arrs,
        not $ null $ tiledInputs inputs,
        gtid `notNameIn` branch_variant,
        (prestms', poststms') <-
          preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms),
        used <- freeIn stm_to_tile <> freeIn poststms' <> freeIn stms_res =
          Just . injectPrelude initial_space variance prestms' used
            <$> tileGeneric
              (tiling1d $ reverse top_space_rev)
              initial_lvl
              res_ts
              (stmPat stm_to_tile)
              gtid
              kdim
              w
              form
              inputs
              poststms'
              stms_res
      -- Tiling inside for-loop.
      | DoLoop merge (ForLoop i it bound []) loopbody <- stmExp stm_to_tile,
        not $ any ((`nameIn` freeIn merge) . paramName . fst) merge,
        (prestms', poststms') <-
          preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms) = do
          -- Everything the loop bound is variant to is also treated
          -- as branch-variant inside the loop body.
          let branch_variant' =
                branch_variant
                  <> mconcat
                    ( map
                        (flip (M.findWithDefault mempty) variance)
                        (namesToList (freeIn bound))
                    )
              merge_params = map fst merge

          maybe_tiled <-
            localScope (M.insert i (IndexName it) $ scopeOfFParams merge_params)
              $ tileInBody
                branch_variant'
                variance
                initial_lvl
                initial_space
                (map paramType merge_params)
              $ mkBody (bodyStms loopbody) (bodyResult loopbody)

          case maybe_tiled of
            Nothing -> next
            Just tiled ->
              Just
                <$> tileDoLoop
                  initial_space
                  variance
                  prestms'
                  (freeIn loopbody <> freeIn merge)
                  tiled
                  res_ts
                  (stmPat stm_to_tile)
                  (stmAux stm_to_tile)
                  merge
                  i
                  it
                  bound
                  poststms'
                  stms_res
      | otherwise = next
      where
        -- Give up on this statement: it becomes part of the prelude
        -- and is in scope for the rest of the scan.
        next =
          localScope (scopeOf stm_to_tile) $
            descend (prestms <> oneStm stm_to_tile) poststms

-- | Move statements from prelude to postlude if they are not used in
-- the tiled statement anyway.
--
-- A prelude statement stays in the prelude when at least one of the
-- names it binds is either free in the statement to tile, or is in
-- the variance set of a name that is free in it.  All other prelude
-- statements are placed in front of the postlude.  The relative order
-- of statements within each part is preserved.
preludeToPostlude ::
  VarianceTable ->
  Stms GPU ->
  Stm GPU ->
  Stms GPU ->
  (Stms GPU, Stms GPU)
preludeToPostlude variance prelude stm_to_tile postlude =
  (prelude_used, prelude_not_used <> postlude)
  where
    used_in_tiled = freeIn stm_to_tile

    -- The free names of the tiled statement, plus everything those
    -- names are variant to according to the variance table.
    used_in_stm_variant =
      (used_in_tiled <>) $
        mconcat $
          map (flip (M.findWithDefault mempty) variance) $
            namesToList used_in_tiled

    used stm =
      any (`nameIn` used_in_stm_variant) $
        patNames $
          stmPat stm

    (prelude_used, prelude_not_used) =
      Seq.partition used prelude

-- | Partition prelude statements preceding a tiled loop (or something
-- containing a tiled loop) into three categories:
--
-- 1) Group-level statements that are invariant to the threads in the group.
--
-- 2) Thread-variant statements that should be computed once with a segmap_thread_scalar.
--
-- 3) Thread-variant statements that should be recomputed whenever
-- they are needed.
--
-- The third category duplicates computation, so we only want to do it
-- when absolutely necessary.  Currently, this is necessary for
-- results that are views of an array (slicing, rotate, etc) and whose
-- results are used after the prelude, because these cannot be
-- efficiently represented by a scalar segmap (they'll be manifested
-- in memory).
--
-- The result triple is in the order (1), (2), (3).  The arguments are
-- the variance table, the prelude statements, the names considered
-- private to a thread, and the names used after the prelude.
partitionPrelude ::
  VarianceTable ->
  Stms GPU ->
  Names ->
  Names ->
  (Stms GPU, Stms GPU, Stms GPU)
partitionPrelude variance prestms private used_after =
  (invariant_prestms, precomputed_variant_prestms, recomputed_variant_prestms)
  where
    -- Only the first name bound by the statement is consulted.
    invariantTo names stm =
      case patNames (stmPat stm) of
        [] -> True -- Does not matter.
        v : _ -> all (`notNameIn` names) (namesToList $ M.findWithDefault mempty v variance)

    consumed v = v `nameIn` consumed_in_prestms
    consumedStm stm = any consumed (patNames (stmPat stm))

    -- Names bound in the prelude that are consumed somewhere in the
    -- prelude.
    later_consumed =
      namesFromList $
        concatMap (patNames . stmPat) $
          stmsToList $
            Seq.filter consumedStm prestms

    -- Category 1: invariant to the private names, and neither
    -- consumed later nor variant to anything consumed later.
    groupInvariant stm =
      invariantTo private stm
        && all (`notNameIn` later_consumed) (patNames (stmPat stm))
        && invariantTo later_consumed stm
    (invariant_prestms, variant_prestms) =
      Seq.partition groupInvariant prestms

    consumed_in_prestms =
      foldMap consumedInStm $ fst $ Alias.analyseStms mempty prestms

    -- Expressions whose results are recomputed rather than stored.
    -- An 'Index' only counts when its slice has at least one
    -- dimension left.
    mustBeInlinedExp (BasicOp (Index _ slice)) = not $ null $ sliceDims slice
    mustBeInlinedExp (BasicOp Iota {}) = True
    mustBeInlinedExp (BasicOp Rotate {}) = True
    mustBeInlinedExp (BasicOp Rearrange {}) = True
    mustBeInlinedExp (BasicOp Reshape {}) = True
    mustBeInlinedExp _ = False
    mustBeInlined stm =
      mustBeInlinedExp (stmExp stm)
        && any (`nameIn` used_after) (patNames (stmPat stm))

    must_be_inlined =
      namesFromList $
        concatMap (patNames . stmPat) $
          stmsToList $
            Seq.filter mustBeInlined variant_prestms

    -- Category 3: the statement itself must be inlined, or it is
    -- variant to something that must be.
    recompute stm =
      any (`nameIn` must_be_inlined) (patNames (stmPat stm))
        || not (invariantTo must_be_inlined stm)
    (recomputed_variant_prestms, precomputed_variant_prestms) =
      Seq.partition recompute variant_prestms

-- | Wrap an already-constructed tiled body so that the prelude
-- statements @prestms@ are made available to it.
--
-- Anything that is variant to the "private" names should be
-- considered thread-local.  The private names passed on to the
-- wrapped body are the incoming ones extended with those dimensions
-- of @initial_space@ that are not part of the tiling's own space.
--
-- The prelude is split by 'partitionPrelude' into three groups:
--
-- * invariant statements, which are added directly;
--
-- * precomputed variant statements, which are run once via
--   'doPrelude', with their live results read back through
--   'mkReadPreludeValues';
--
-- * recomputed variant statements, which are attached as 'PrivStms'
--   and hence re-inserted wherever the private statements are added.
--
-- The host statements and the 'Tiling' are returned unchanged.
injectPrelude ::
  SegSpace ->
  VarianceTable ->
  Stms GPU ->
  Names ->
  (Stms GPU, Tiling, TiledBody) ->
  (Stms GPU, Tiling, TiledBody)
injectPrelude initial_space variance prestms used (host_stms, tiling, tiledBody) =
  (host_stms, tiling, tiledBody')
  where
    tiledBody' private privstms = do
      let nontiled = (`notElem` unSegSpace (tilingSpace tiling))
          private' =
            private
              <> namesFromList (map fst (filter nontiled $ unSegSpace initial_space))
          (invariant_prestms, precomputed_variant_prestms, recomputed_variant_prestms) =
            partitionPrelude variance prestms private' used

      addStms invariant_prestms

      -- Only those results of the precomputed statements that are
      -- needed afterwards are returned from the prelude.
      let live_set =
            namesToList $
              liveSet precomputed_variant_prestms $
                used <> freeIn recomputed_variant_prestms
      prelude_arrs <-
        inScopeOf precomputed_variant_prestms $
          doPrelude tiling privstms precomputed_variant_prestms live_set

      let prelude_privstms =
            PrivStms recomputed_variant_prestms $
              mkReadPreludeValues prelude_arrs live_set

      tiledBody private' (prelude_privstms <> privstms)

-- | Lift an already-tiled body into an enclosing sequential
-- @for@-loop.
--
-- The loop's merge parameters are expanded to arrays of the tile
-- shape (see @tileDim@), initialised by a @"tiled_loopinit"@ SegMap,
-- and the loop is rebuilt around the inner tiled body.  Inside the
-- loop body, the original (per-thread) merge parameters are recovered
-- by indexing the expanded ones (@indexMergeParams@).  The statements
-- following the loop (@poststms@ / @poststms_res@) are emitted by
-- 'postludeGeneric'.
--
-- The prelude statements are partitioned with 'partitionPrelude' and
-- handled the same way as in 'injectPrelude'.  The host statements and
-- the 'Tiling' are returned unchanged; only the 'TiledBody' is
-- replaced.
tileDoLoop ::
  SegSpace ->
  VarianceTable ->
  Stms GPU ->
  Names ->
  (Stms GPU, Tiling, TiledBody) ->
  [Type] ->
  Pat Type ->
  StmAux (ExpDec GPU) ->
  [(FParam GPU, SubExp)] ->
  VName ->
  IntType ->
  SubExp ->
  Stms GPU ->
  Result ->
  TileM (Stms GPU, Tiling, TiledBody)
tileDoLoop initial_space variance prestms used_in_body (host_stms, tiling, tiledBody) res_ts pat aux merge i it bound poststms poststms_res = do
  let prestms_used = used_in_body <> freeIn poststms <> freeIn poststms_res
      (invariant_prestms, precomputed_variant_prestms, recomputed_variant_prestms) =
        partitionPrelude variance prestms tiled_kdims prestms_used

  let (mergeparams, mergeinits) = unzip merge

      -- Expand the loop merge parameters to be arrays.
      tileDim t = arrayOf t (tilingTileShape tiling) $ uniqueness t

      merge_scope = M.insert i (IndexName it) $ scopeOfFParams mergeparams

      tiledBody' private privstms = localScope (scopeOf host_stms <> merge_scope) $ do
        addStms invariant_prestms

        let live_set =
              namesToList $
                liveSet precomputed_variant_prestms $
                  freeIn recomputed_variant_prestms <> prestms_used

        prelude_arrs <-
          inScopeOf precomputed_variant_prestms $
            doPrelude tiling privstms precomputed_variant_prestms live_set

        -- Fresh names for the expanded merge parameters.
        mergeparams' <- forM mergeparams $ \(Param attrs pname pt) ->
          Param attrs <$> newVName (baseString pname ++ "_group") <*> pure (tileDim pt)

        let merge_ts = map paramType mergeparams

        let inloop_privstms =
              PrivStms recomputed_variant_prestms $
                mkReadPreludeValues prelude_arrs live_set

        -- Initial values of the expanded merge parameters; computed
        -- under the certificates of the original loop statement.
        mergeinit' <-
          fmap (map Var) $
            certifying (stmAuxCerts aux) $
              tilingSegMap tiling "tiled_loopinit" (scalarLevel tiling) ResultPrivate $
                \in_bounds slice ->
                  fmap varsRes $
                    protectOutOfBounds "loopinit" in_bounds merge_ts $ do
                      addPrivStms slice inloop_privstms
                      addPrivStms slice privstms
                      pure $ subExpsRes mergeinits

        let merge' = zip mergeparams' mergeinit'

        -- Bind each original merge parameter name to the
        -- corresponding slice of its expanded counterpart.
        let indexMergeParams slice =
              localScope (scopeOfFParams mergeparams') $
                forM_ (zip mergeparams mergeparams') $ \(to, from) ->
                  letBindNames [paramName to] . BasicOp . Index (paramName from) $
                    fullSlice (paramType from) slice

            private' =
              private
                <> namesFromList
                  (map paramName mergeparams ++ map paramName mergeparams')

            privstms' =
              PrivStms mempty indexMergeParams <> privstms <> inloop_privstms

        loopbody' <-
          localScope (scopeOfFParams mergeparams') . runBodyBuilder $
            resultBody . map Var
              <$> tiledBody private' privstms'
        accs' <-
          letTupExp "tiled_inside_loop" $
            DoLoop merge' (ForLoop i it bound []) loopbody'

        postludeGeneric
          tiling
          (privstms <> inloop_privstms)
          pat
          accs'
          poststms
          poststms_res
          res_ts

  pure (host_stms, tiling, tiledBody')
  where
    -- Dimensions of the initial space that are not part of the
    -- tiling's own space.
    tiled_kdims =
      namesFromList $
        map fst $
          filter (`notElem` unSegSpace (tilingSpace tiling)) $
            unSegSpace initial_space

-- | Run the prelude statements @prestms@ inside a @"prelude"@ SegMap
-- built with the tiling's 'tilingSegMap' at 'scalarLevel', returning
-- the names bound to the per-thread values of @prestms_live@.
--
-- The private statements are inserted before the prelude, and the
-- whole computation is guarded with 'protectOutOfBounds'.
doPrelude :: Tiling -> PrivStms -> Stms GPU -> [VName] -> Builder GPU [VName]
doPrelude tiling privstms prestms prestms_live =
  -- Create a SegMap that takes care of the prelude for every thread.
  tilingSegMap tiling "prelude" (scalarLevel tiling) ResultPrivate $
    \in_bounds slice -> do
      ts <- mapM lookupType prestms_live
      fmap varsRes . protectOutOfBounds "pre" in_bounds ts $ do
        addPrivStms slice privstms
        addStms prestms
        pure $ varsRes prestms_live

-- | The names bound by the patterns of @stms@ that also occur free in
-- @after@.
liveSet :: FreeIn a => Stms GPU -> a -> Names
liveSet stms after =
  namesFromList (concatMap (patNames . stmPat) stms)
    `namesIntersection` freeIn after

-- | Decide whether a statement can be tiled.
--
-- Returns 'Just' only when the statement is a 'Screma' that
-- 'isRedomapSOAC' recognises, and all of the following hold:
--
-- * the map lambda returns exactly the types of the (combined)
--   reduction lambda, i.e. there are no mapout arrays;
--
-- * there is at least one input array;
--
-- * every return type and every parameter of the map lambda is a
--   primitive type.
--
-- On success the result is the width, the input arrays, and the
-- commutativity, reduction lambda, neutral elements and map lambda.
tileable ::
  Stm GPU ->
  Maybe
    ( SubExp,
      [VName],
      (Commutativity, Lambda GPU, [SubExp], Lambda GPU)
    )
tileable stm
  | Op (OtherOp (Screma w arrs form)) <- stmExp stm,
    Just (reds, map_lam) <- isRedomapSOAC form,
    Reduce red_comm red_lam red_nes <- singleReduce reds,
    lambdaReturnType map_lam == lambdaReturnType red_lam, -- No mapout arrays.
    not $ null arrs,
    all primType $ lambdaReturnType map_lam,
    all (primType . paramType) $ lambdaParams map_lam =
      Just (w, arrs, (red_comm, red_lam, red_nes, map_lam))
  | otherwise =
      Nothing

-- | We classify the inputs to the tiled loop as whether they are
-- tileable (and with what permutation of the kernel indexes) or not.
-- In practice, we should have at least one tileable array per loop,
-- but this is not enforced in our representation.
data InputArray
  = -- | A tileable array, together with its permutation.
    InputTile [Int] VName
  | -- | An array that is not to be tiled.
    InputDontTile VName

-- | Extract the tileable inputs as @(array, permutation)@ pairs,
-- dropping every 'InputDontTile'.
tiledInputs :: [InputArray] -> [(VName, [Int])]
tiledInputs = mapMaybe f
  where
    f (InputTile perm arr) = Just (arr, perm)
    f InputDontTile {} = Nothing

-- | A tile (or an original untiled array).
data InputTile
  = -- | A tile, together with the permutation of the input it was
    -- produced from (see 'inputsToTiles').
    InputTiled [Int] VName
  | -- | The original array, which was not tiled.
    InputUntiled VName

-- | Pair the classified inputs with the names of the tiles that were
-- read for them.
--
-- The second argument holds one tile name per 'InputTile' input, in
-- order.  Each 'InputTile' consumes the next tile name and keeps its
-- permutation; each 'InputDontTile' is passed through as
-- 'InputUntiled' without consuming a tile name.  The result ends when
-- the inputs run out, or when an 'InputTile' is encountered with no
-- tile names remaining.
inputsToTiles :: [InputArray] -> [VName] -> [InputTile]
inputsToTiles (InputTile perm _ : inputs) (tile : tiles) =
  InputTiled perm tile : inputsToTiles inputs tiles
inputsToTiles (InputDontTile arr : inputs) tiles =
  InputUntiled arr : inputsToTiles inputs tiles
inputsToTiles _ _ = []

-- | Slice out of an untiled array the part corresponding to a given
-- tile.  The slice covers the outermost dimension only, starts at
-- @tile_id * full_tile_size@, has @this_tile_size@ elements, and has
-- unit stride.
--
-- The actual tile size may be smaller for the last tile, so we have
-- to be careful now: the offset is computed from the full tile size,
-- while the length of the slice is the size of this particular tile.
sliceUntiled ::
  MonadBuilder m =>
  VName ->
  SubExp ->
  SubExp ->
  SubExp ->
  m VName
sliceUntiled arr tile_id full_tile_size this_tile_size = do
  arr_t <- lookupType arr
  slice_offset <-
    letSubExp "slice_offset" =<< toExp (pe64 tile_id * pe64 full_tile_size)
  let slice = DimSlice slice_offset this_tile_size (intConst Int64 1)
  letExp "untiled_slice" $
    BasicOp $
      Index arr $
        fullSlice arr_t [slice]

-- | Statements that we insert directly into every thread-private
-- SegMaps.  This is for things that cannot efficiently be computed
-- once in advance in the prelude SegMap, primarily (exclusively?)
-- array slicing operations.
--
-- The 'ReadPrelude' component is an action that is run, given the
-- thread's slice, before the statements themselves are added (see
-- 'addPrivStms').
data PrivStms = PrivStms (Stms GPU) ReadPrelude

-- | Construct a 'PrivStms' from plain statements, with a
-- 'ReadPrelude' action that does nothing.
privStms :: Stms GPU -> PrivStms
privStms stms = PrivStms stms $ const $ pure ()

-- | Insert private statements into the current builder.  The
-- 'ReadPrelude' action is run first with the given slice, and the
-- statements are added afterwards.
addPrivStms :: [DimIndex SubExp] -> PrivStms -> Builder GPU ()
addPrivStms local_slice (PrivStms stms readPrelude) = do
  readPrelude local_slice
  addStms stms

-- | Combining two 'PrivStms' concatenates their statements and
-- sequences their 'ReadPrelude' actions, left operand first.
instance Semigroup PrivStms where
  PrivStms stms_x readPrelude_x <> PrivStms stms_y readPrelude_y =
    PrivStms stms_z readPrelude_z
    where
      stms_z = stms_x <> stms_y
      readPrelude_z slice = readPrelude_x slice >> readPrelude_y slice

-- | The empty 'PrivStms' has no statements and a 'ReadPrelude' action
-- that does nothing.
instance Monoid PrivStms where
  mempty = privStms mempty

-- | An action that, given a slice, inserts statements into the
-- current builder.  It is run by 'addPrivStms' before the private
-- statements themselves are added.
type ReadPrelude = [DimIndex SubExp] -> Builder GPU ()

-- | The arguments passed to 'tilingProcessTile'.
data ProcessTileArgs = ProcessTileArgs
  { processPrivStms :: PrivStms,
    processComm :: Commutativity,
    processRedLam :: Lambda GPU,
    processMapLam :: Lambda GPU,
    processTiles :: [InputTile],
    processAcc :: [VName],
    processTileId :: SubExp
  }

-- | The arguments passed to 'tilingProcessResidualTile'.
data ResidualTileArgs = ResidualTileArgs
  { residualPrivStms :: PrivStms,
    residualComm :: Commutativity,
    residualRedLam :: Lambda GPU,
    residualMapLam :: Lambda GPU,
    residualInput :: [InputArray],
    residualAcc :: [VName],
    residualInputSize :: SubExp,
    residualNumWholeTiles :: SubExp
  }

-- | Information about a loop that has been tiled inside a kernel, as
-- well as the kinds of changes that we would then like to perform on
-- the kernel.
data Tiling = Tiling
  { -- | Construct a SegMap from a description string, a level, a
    -- result manifest, and a function producing the body.  The
    -- function is given a boolean 'PrimExp' and a slice.
    tilingSegMap ::
      String ->
      SegLevel ->
      ResultManifest ->
      (PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result) ->
      Builder GPU [VName],
    -- The boolean PrimExp indicates whether they are in-bounds.

    -- | Produce the tiles for the given inputs.
    tilingReadTile ::
      TileKind ->
      PrivStms ->
      SubExp ->
      [InputArray] ->
      Builder GPU [InputTile],
    -- | Process a tile, as described by the 'ProcessTileArgs'.
    tilingProcessTile ::
      ProcessTileArgs ->
      Builder GPU [VName],
    -- | Process the residual tile, as described by the
    -- 'ResidualTileArgs'.
    tilingProcessResidualTile ::
      ResidualTileArgs ->
      Builder GPU [VName],
    tilingTileReturns :: VName -> Builder GPU KernelResult,
    -- | The space of the tiling; dimensions of the original space
    -- that are not found here are treated as private names by
    -- 'injectPrelude' and 'tileDoLoop'.
    tilingSpace :: SegSpace,
    -- | The shape by which loop merge parameters are expanded in
    -- 'tileDoLoop'.
    tilingTileShape :: Shape,
    -- | The level from which 'scalarLevel' takes the number of groups
    -- and the group size.
    tilingLevel :: SegLevel,
    tilingNumWholeTiles :: Builder GPU SubExp
  }

-- | A function that constructs a 'Tiling'.  'tileGeneric' applies it
-- to the initial level, the @gtids@, the @kdims@, and a 'SubExp'
-- that it calls @w@ (presumably the width of the tiled operation;
-- confirm against the call sites).
type DoTiling gtids kdims =
  SegLevel -> gtids -> kdims -> SubExp -> Builder GPU Tiling

-- | A 'SegThread' level with the same number of groups and group size
-- as the level of the tiling, and without virtualisation.
scalarLevel :: Tiling -> SegLevel
scalarLevel tiling =
  SegThread (segNumGroups lvl) (segGroupSize lvl) SegNoVirt
  where
    lvl = tilingLevel tiling

-- | Guard a computation by a bounds check.
--
-- The result is an @if@ on @in_bounds@, bound under the name hint
-- @desc@: the true branch is the body built by the given action, and
-- the false branch produces one "blank" value for each type in @ts@.
-- For an accumulator type, the blank value is a free variable of the
-- body that has that same type; for any other type it is 'eBlank'.
protectOutOfBounds ::
  String ->
  PrimExp VName ->
  [Type] ->
  Builder GPU Result ->
  Builder GPU [VName]
protectOutOfBounds desc in_bounds ts m = do
  -- This is more complicated than you might expect, because we need
  -- to be able to produce a blank accumulator, which eBlank cannot
  -- do.  By the linear type rules of accumulators, if the body
  -- returns an accumulator of type 'acc_t', then a unique variable of
  -- type 'acc_t' must also be free in the body.  This means we can
  -- find it based just on the type.
  m_body <- insertStmsM $ mkBody mempty <$> m
  let m_body_free = namesToList $ freeIn m_body
  -- Accumulator-typed free variables of the body, keyed by type.
  t_to_v <-
    filter (isAcc . fst)
      <$> (zip <$> mapM lookupType m_body_free <*> pure m_body_free)
  let blank t = maybe (eBlank t) (pure . BasicOp . SubExp . Var) $ lookup t t_to_v
  letTupExp desc =<< eIf (toExp in_bounds) (pure m_body) (eBody $ map blank ts)

-- | Produce the final per-thread results after a tiled loop.
--
-- A @"thread_res"@ SegMap is built in which each name of @pat@ is
-- bound to this thread's slice of the corresponding array in @accs'@.
-- If there are no statements following the loop, the private
-- statements are added and @poststms_res@ is returned as-is;
-- otherwise the private statements and @poststms@ are run under
-- 'protectOutOfBounds', with @res_ts@ as the result types.
postludeGeneric ::
  Tiling ->
  PrivStms ->
  Pat Type ->
  [VName] ->
  Stms GPU ->
  Result ->
  [Type] ->
  Builder GPU [VName]
postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts =
  tilingSegMap tiling "thread_res" (scalarLevel tiling) ResultPrivate $ \in_bounds slice -> do
    -- Read our per-thread result from the tiled loop.
    forM_ (zip (patNames pat) accs') $ \(us, everyone) -> do
      everyone_t <- lookupType everyone
      letBindNames [us] $ BasicOp $ Index everyone $ fullSlice everyone_t slice

    if null poststms
      then do
        -- The privstms may still be necessary for the result.
        addPrivStms slice privstms
        pure poststms_res
      else fmap varsRes $
        protectOutOfBounds "postlude" in_bounds res_ts $ do
          addPrivStms slice privstms
          addStms poststms
          pure poststms_res

-- | The code-generating part of a tiling, as returned by 'tileGeneric'.
-- Given a set of names and the 'PrivStms', it emits the tiled
-- computation into the builder and returns the names of its results.
-- (The 'Names' argument is ignored by the body that 'tileGeneric'
-- constructs.)
type TiledBody = Names -> PrivStms -> Builder GPU [VName]

-- | Shared driver for the concrete tiling strategies.
--
-- Runs @doTiling@ in a builder to obtain a 'Tiling' together with the
-- statements needed to set it up, and returns those along with a
-- 'TiledBody'.  When executed, the 'TiledBody' emits, in order:
--
--   1. a @"mergeinit"@ SegMap that seeds the per-thread accumulators
--      with the neutral elements of the reduction;
--
--   2. a sequential loop over the whole tiles, in which a tile is read
--      collectively and then folded into the accumulators;
--
--   3. the processing of a possible residual tile; and
--
--   4. the postlude, via 'postludeGeneric'.
--
-- The @form@ tuple is @(commutativity, reduction lambda, neutral
-- elements, map lambda)@.  @w@ is forwarded to @doTiling@ and is also
-- passed on as the input size for the residual tile.
tileGeneric ::
  DoTiling gtids kdims ->
  SegLevel ->
  [Type] ->
  Pat Type ->
  gtids ->
  kdims ->
  SubExp ->
  (Commutativity, Lambda GPU, [SubExp], Lambda GPU) ->
  [InputArray] ->
  Stms GPU ->
  Result ->
  TileM (Stms GPU, Tiling, TiledBody)
tileGeneric doTiling initial_lvl res_ts pat gtids kdims w form inputs poststms poststms_res = do
  (tiling, tiling_stms) <- runBuilder $ doTiling initial_lvl gtids kdims w

  pure (tiling_stms, tiling, tiledBody tiling)
  where
    (red_comm, red_lam, red_nes, map_lam) = form

    tiledBody :: Tiling -> Names -> PrivStms -> Builder GPU [VName]
    tiledBody tiling _private privstms = do
      let tile_shape = tilingTileShape tiling

      num_whole_tiles <- tilingNumWholeTiles tiling

      -- We don't use a Replicate here, because we want to enforce a
      -- scalar memory space.
      mergeinits <- tilingSegMap tiling "mergeinit" (scalarLevel tiling) ResultPrivate $ \in_bounds slice ->
        -- Constant neutral elements (a common case) do not need protection from OOB.
        if freeIn red_nes == mempty
          then pure $ subExpsRes red_nes
          else fmap varsRes $
            protectOutOfBounds "neutral" in_bounds (lambdaReturnType red_lam) $ do
              addPrivStms slice privstms
              pure $ subExpsRes red_nes

      -- One loop parameter per reduction parameter, each holding an
      -- accumulator array of the tile shape.
      merge <- forM (zip (lambdaParams red_lam) mergeinits) $ \(p, mergeinit) ->
        (,)
          <$> newParam
            (baseString (paramName p) ++ "_merge")
            (paramType p `arrayOfShape` tile_shape `toDecl` Unique)
          <*> pure (Var mergeinit)

      tile_id <- newVName "tile_id"
      let loopform = ForLoop tile_id Int64 num_whole_tiles []
      loopbody <- renameBody <=< runBodyBuilder $
        inScopeOf loopform $
          localScope (scopeOfFParams $ map fst merge) $ do
            -- Collectively read a tile.  This loop only ranges over the
            -- whole tiles, so every element read is in bounds and the
            -- tile is read as 'TileFull'; any residual elements are
            -- handled separately after the loop.
            tile <- tilingReadTile tiling TileFull privstms (Var tile_id) inputs

            -- Now each thread performs a traversal of the tile and
            -- updates its accumulator.
            let accs =
                  map (paramName . fst) merge
                tile_args =
                  ProcessTileArgs privstms red_comm red_lam map_lam tile accs (Var tile_id)
            resultBody . map Var <$> tilingProcessTile tiling tile_args

      accs <- letTupExp "accs" $ DoLoop merge loopform loopbody

      -- We possibly have to traverse a residual tile.  The lambdas are
      -- renamed because they have already been inserted once, in the
      -- loop body above.
      red_lam' <- renameLambda red_lam
      map_lam' <- renameLambda map_lam
      let residual_args =
            ResidualTileArgs privstms red_comm red_lam' map_lam' inputs accs w num_whole_tiles
      accs' <- tilingProcessResidualTile tiling residual_args

      -- Create a SegMap that takes care of the postlude for every thread.
      postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts

-- | Construct a 'ReadPrelude'.  The two lists are zipped; for each pair
-- the second name is bound to the first (an array) indexed at the given
-- slice, with the slice padded to a full slice of the array's type.
mkReadPreludeValues :: [VName] -> [VName] -> ReadPrelude
mkReadPreludeValues prestms_live_arrs prestms_live slice =
  forM_ (zip prestms_live_arrs prestms_live) readLive
  where
    -- Bind @v@ to the element of @arr@ selected by @slice@.
    readLive (arr, v) = do
      arr_t <- lookupType arr
      letBindNames [v] . BasicOp . Index arr $ fullSlice arr_t slice

-- | Build a 'TileReturns' kernel result for @arr@.  When there are
-- dimensions on top (and @arr@ actually has array dimensions), the
-- array is first reshaped so that it gains one unit dimension per
-- top dimension; the tile dimensions are extended correspondingly.
tileReturns :: [(VName, SubExp)] -> [(SubExp, SubExp)] -> VName -> Builder GPU KernelResult
tileReturns dims_on_top dims arr = do
  arr_t <- lookupType arr
  let ones = map (const (intConst Int64 1)) dims_on_top
      arr_dims = arrayDims arr_t
  arr' <-
    -- The check on the array's own dimensions is for accumulators.
    if null dims_on_top || null arr_dims
      then pure arr
      else
        letExp (baseString arr) . BasicOp $
          Reshape ReshapeArbitrary (Shape (ones ++ arr_dims)) arr
  pure $ TileReturns mempty (zip (map snd dims_on_top) ones ++ dims) arr'

-- | Classify an input array for one-dimensional tiling.  If the
-- variance-table entry for @arr@ contains @gtid@, the array is marked
-- 'InputDontTile'; otherwise (including when the array has no entry in
-- the table) it is marked @'InputTile' [0]@.
is1DTileable :: VName -> M.Map VName Names -> VName -> InputArray
is1DTileable gtid variance arr =
  if gtid `nameIn` M.findWithDefault mempty arr variance
    then InputDontTile arr
    else InputTile [0] arr

-- | Bind @gtid@ to @gid * group_size + ltid@.
reconstructGtids1D ::
  Count GroupSize SubExp ->
  VName ->
  VName ->
  VName ->
  Builder GPU ()
reconstructGtids1D group_size gtid gid ltid = do
  gtid_exp <- toExp $ le64 gid * pe64 (unCount group_size) + le64 ltid
  letBindNames [gtid] gtid_exp

-- | Collectively read one tile of every tiled input array.  A
-- thread-level SegMap is emitted in which the thread with local id
-- @ltid@ reads element @tile_id * tile_size + ltid@ of each tiled
-- array.  With 'TilePartial' the read is guarded by a bounds check
-- against the outer size of the arrays, and out-of-bounds threads
-- produce a blank value instead; with 'TileFull' the read is
-- unconditional.
readTile1D ::
  SubExp ->
  VName ->
  VName ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  TileKind ->
  PrivStms ->
  SubExp ->
  [InputArray] ->
  Builder GPU [InputTile]
readTile1D tile_size gid gtid num_groups group_size kind privstms tile_id inputs = do
  tiles <- segMap1D "full_tile" lvl ResultNoSimplify readElems
  pure $ inputsToTiles inputs tiles
  where
    lvl = SegThread num_groups group_size SegNoVirt

    -- What a single thread contributes to the tile.
    readElems :: VName -> Builder GPU Result
    readElems ltid = do
      j <-
        letSubExp "j"
          =<< toExp (pe64 tile_id * pe64 tile_size + le64 ltid)

      reconstructGtids1D group_size gtid gid ltid
      addPrivStms [DimFix $ Var ltid] privstms

      let arrs = map fst $ tiledInputs inputs
      arr_ts <- mapM lookupType arrs
      let w = arraysSize 0 arr_ts
          -- No need for fullSlice because we are tiling only prims.
          readTileElem arr =
            letExp "tile_elem" $ BasicOp $ Index arr $ Slice [DimFix j]

      elems <- case kind of
        TileFull ->
          mapM readTileElem arrs
        TilePartial ->
          letTupExp "pre1d"
            =<< eIf
              (toExp $ pe64 j .<. pe64 w)
              (resultBody . map Var <$> mapM readTileElem arrs)
              (eBody $ map (eBlank . rowType) arr_ts)
      pure $ varsRes elems

-- | Let every thread fold one tile into its accumulator.  A
-- thread-level SegMap is emitted in which each thread reads its own
-- accumulator elements and, provided @gtid < kdim@, runs a redomap of
-- size @tile_size@ over the tile using those elements in place of the
-- neutral elements.  Threads with @gtid >= kdim@ return their
-- accumulators unchanged.
processTile1D ::
  VName ->
  VName ->
  SubExp ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  ProcessTileArgs ->
  Builder GPU [VName]
processTile1D gid gtid kdim tile_size num_groups group_size tile_args =
  segMap1D "acc" lvl ResultPrivate $ \ltid -> do
    reconstructGtids1D group_size gtid gid ltid
    addPrivStms [DimFix $ Var ltid] privstms

    -- We replace the neutral elements with the accumulators (this is
    -- OK because the parallel semantics are not used after this
    -- point).
    thread_accs <- mapM (readAcc ltid) accs
    tiles' <- mapM sliceTile tiles

    let form' = redomapSOAC [Reduce red_comm red_lam thread_accs] map_lam
    fmap varsRes $
      letTupExp "acc"
        =<< eIf
          (toExp $ le64 gtid .<. pe64 kdim)
          (eBody [pure $ Op $ OtherOp $ Screma tile_size tiles' form'])
          (resultBodyM thread_accs)
  where
    lvl = SegThread num_groups group_size SegNoVirt

    red_comm = processComm tile_args
    privstms = processPrivStms tile_args
    map_lam = processMapLam tile_args
    red_lam = processRedLam tile_args
    tiles = processTiles tile_args
    tile_id = processTileId tile_args
    accs = processAcc tile_args

    -- This thread's element of an accumulator array.
    readAcc :: VName -> VName -> Builder GPU SubExp
    readAcc ltid acc =
      letSubExp "acc" $ BasicOp $ Index acc $ Slice [DimFix $ Var ltid]

    -- Tiled inputs are used as they are; untiled inputs are sliced to
    -- the part corresponding to this tile.
    sliceTile :: InputTile -> Builder GPU VName
    sliceTile (InputTiled _ arr) = pure arr
    sliceTile (InputUntiled arr) = sliceUntiled arr tile_id tile_size tile_size

-- | Process the elements left over after all whole tiles.  The residual
-- count is computed as @w `rem` tile_size@.  If it is zero the
-- accumulators are returned unchanged; otherwise a tile is read with
-- bounds checking, cut down to its first @residual_input@ elements, and
-- handed to 'processTile1D' with @residual_input@ in place of the tile
-- size.
processResidualTile1D ::
  VName ->
  VName ->
  SubExp ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  ResidualTileArgs ->
  Builder GPU [VName]
processResidualTile1D gid gtid kdim tile_size num_groups group_size args = do
  -- The number of residual elements that are not covered by
  -- the whole tiles.
  residual_input <-
    letSubExp "residual_input" . BasicOp $
      BinOp (SRem Int64 Unsafe) w tile_size

  letTupExp "acc_after_residual"
    =<< eIf
      (toExp $ pe64 residual_input .==. 0)
      (resultBodyM $ map Var accs)
      (nonemptyTile residual_input)
  where
    red_comm = residualComm args
    map_lam = residualMapLam args
    red_lam = residualRedLam args
    privstms = residualPrivStms args
    inputs = residualInput args
    accs = residualAcc args
    num_whole_tiles = residualNumWholeTiles args
    w = residualInputSize args

    nonemptyTile :: SubExp -> Builder GPU (Body GPU)
    nonemptyTile residual_input = runBodyBuilder $ do
      -- Collectively construct a tile.  Threads that are out-of-bounds
      -- provide a blank dummy value.
      full_tiles <-
        readTile1D tile_size gid gtid num_groups group_size TilePartial privstms num_whole_tiles inputs

      -- Keep only the first @residual_input@ elements of tiled inputs;
      -- untiled inputs pass through untouched.
      let prefix = DimSlice (intConst Int64 0) residual_input (intConst Int64 1)

          trim :: InputTile -> Builder GPU InputTile
          trim (InputUntiled arr) = pure $ InputUntiled arr
          trim (InputTiled perm tile) =
            InputTiled perm
              <$> letExp "partial_tile" (BasicOp $ Index tile $ Slice [prefix])

      tiles <- mapM trim full_tiles

      -- Now each thread performs a traversal of the tile and
      -- updates its accumulator.
      let tile_args =
            ProcessTileArgs privstms red_comm red_lam map_lam tiles accs num_whole_tiles
      accs' <- processTile1D gid gtid kdim residual_input num_groups group_size tile_args
      pure $ resultBody $ map Var accs'

tiling1d :: [(VName, SubExp)] -> DoTiling VName SubExp
tiling1d :: [(VName, SubExp)] -> DoTiling VName SubExp
tiling1d [(VName, SubExp)]
dims_on_top SegLevel
initial_lvl VName
gtid SubExp
kdim SubExp
w = do
  VName
gid <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid"
  VName
gid_flat <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_flat"

  (SegLevel
lvl, SegSpace
space) <-
    if forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(VName, SubExp)]
dims_on_top
      then
        forall (f :: * -> *) a. Applicative f => a -> f a
pure
          ( Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegGroup (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
initial_lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
initial_lvl) forall a b. (a -> b) -> a -> b
$ SegLevel -> SegVirt
segVirt SegLevel
initial_lvl,
            VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat [(VName
gid, forall {k} (u :: k) e. Count u e -> e
unCount forall a b. (a -> b) -> a -> b
$ SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
initial_lvl)]
          )
      else do
        SubExp
group_size <-
          forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"computed_group_size" forall a b. (a -> b) -> a -> b
$
            forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
              BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> BinOp
SMin IntType
Int64) (forall {k} (u :: k) e. Count u e -> e
unCount (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
initial_lvl)) SubExp
kdim

        -- How many groups we need to exhaust the innermost dimension.
        SubExp
ldim <-
          forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"ldim" forall a b. (a -> b) -> a -> b
$
            forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
              BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SDivUp IntType
Int64 Safety
Unsafe) SubExp
kdim SubExp
group_size

        SubExp
num_groups <-
          forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"computed_num_groups"
            forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) SubExp
ldim (forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> b
snd [(VName, SubExp)]
dims_on_top)

        forall (f :: * -> *) a. Applicative f => a -> f a
pure
          ( Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegGroup (forall {k} (u :: k) e. e -> Count u e
Count SubExp
num_groups) (forall {k} (u :: k) e. e -> Count u e
Count SubExp
group_size) SegVirt
SegNoVirt,
            VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat forall a b. (a -> b) -> a -> b
$ [(VName, SubExp)]
dims_on_top forall a. [a] -> [a] -> [a]
++ [(VName
gid, SubExp
ldim)]
          )
  let tile_size :: SubExp
tile_size = forall {k} (u :: k) e. Count u e -> e
unCount forall a b. (a -> b) -> a -> b
$ SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl

  forall (f :: * -> *) a. Applicative f => a -> f a
pure
    Tiling
      { tilingSegMap :: [Char]
-> SegLevel
-> ResultManifest
-> (PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName]
tilingSegMap = \[Char]
desc SegLevel
lvl' ResultManifest
manifest PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result
f -> [Char]
-> SegLevel
-> ResultManifest
-> (VName -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName]
segMap1D [Char]
desc SegLevel
lvl' ResultManifest
manifest forall a b. (a -> b) -> a -> b
$ \VName
ltid -> do
          forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
gtid]
            forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
gid forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_size forall a. Num a => a -> a -> a
+ forall a. a -> TPrimExp Int64 a
le64 VName
ltid)
          PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result
f (forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped forall a b. (a -> b) -> a -> b
$ forall a. a -> TPrimExp Int64 a
le64 VName
gtid forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
kdim) [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid],
        tilingReadTile :: TileKind
-> PrivStms -> SubExp -> [InputArray] -> Builder GPU [InputTile]
tilingReadTile =
          SubExp
-> VName
-> VName
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> TileKind
-> PrivStms
-> SubExp
-> [InputArray]
-> Builder GPU [InputTile]
readTile1D SubExp
tile_size VName
gid VName
gtid (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl),
        tilingProcessTile :: ProcessTileArgs -> BuilderT GPU (State VNameSource) [VName]
tilingProcessTile =
          VName
-> VName
-> SubExp
-> SubExp
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> ProcessTileArgs
-> BuilderT GPU (State VNameSource) [VName]
processTile1D VName
gid VName
gtid SubExp
kdim SubExp
tile_size (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl),
        tilingProcessResidualTile :: ResidualTileArgs -> BuilderT GPU (State VNameSource) [VName]
tilingProcessResidualTile =
          VName
-> VName
-> SubExp
-> SubExp
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> ResidualTileArgs
-> BuilderT GPU (State VNameSource) [VName]
processResidualTile1D VName
gid VName
gtid SubExp
kdim SubExp
tile_size (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl),
        tilingTileReturns :: VName -> Builder GPU KernelResult
tilingTileReturns = [(VName, SubExp)]
-> [(SubExp, SubExp)] -> VName -> Builder GPU KernelResult
tileReturns [(VName, SubExp)]
dims_on_top [(SubExp
kdim, SubExp
tile_size)],
        tilingTileShape :: Shape
tilingTileShape = forall d. [d] -> ShapeBase d
Shape [SubExp
tile_size],
        tilingNumWholeTiles :: Builder GPU SubExp
tilingNumWholeTiles =
          forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"num_whole_tiles" forall a b. (a -> b) -> a -> b
$
            forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
              BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SQuot IntType
Int64 Safety
Unsafe) SubExp
w SubExp
tile_size,
        tilingLevel :: SegLevel
tilingLevel = SegLevel
lvl,
        tilingSpace :: SegSpace
tilingSpace = SegSpace
space
      }

-- | Decide how the array @arr@ should be treated with respect to the
-- two innermost dimensions of @dims@ (its last two elements).
--
-- Let @j@ be the last and @i@ the second-to-last element of @dims@,
-- and let the variance of @arr@ be its entry in @variance@ (empty if
-- absent).  Provided neither @i@ nor @j@ occurs in @branch_variant@:
--
-- * variant to @i@ but not to @j@ gives @'InputTile' [0, 1] arr@;
--
-- * variant to @j@ but not to @i@ gives @'InputTile' [1, 0] arr@.
--
-- In every other case the result is @'InputDontTile' arr@.  Returns
-- 'Nothing' only when @dims@ has fewer than two elements.
invariantToOneOfTwoInnerDims ::
  Names ->
  M.Map VName Names ->
  [VName] ->
  VName ->
  Maybe InputArray
invariantToOneOfTwoInnerDims branch_variant variance dims arr =
  case reverse dims of
    j : i : _ -> Just $ classify j i
    _ -> Nothing
  where
    -- Names that 'arr' is variant to, according to the table.
    variant_to = M.findWithDefault mempty arr variance

    classify j i
      | branch_invariant && i `nameIn` variant_to && j `notNameIn` variant_to =
          InputTile [0, 1] arr
      | branch_invariant && j `nameIn` variant_to && i `notNameIn` variant_to =
          InputTile [1, 0] arr
      | otherwise =
          InputDontTile arr
      where
        -- Tiling is only considered when neither inner dimension is
        -- listed as branch-variant.
        branch_invariant =
          not $ nameIn j branch_variant || nameIn i branch_variant

-- | Reconstruct the original gtids from group and local IDs.
--
-- Binds @gtid_x@ to @gid_x * tile_size + ltid_x@ and @gtid_y@ to
-- @gid_y * tile_size + ltid_y@, in that order.  The arguments are the
-- tile size, the pair of gtid names to bind, the pair of group IDs
-- and the pair of local thread IDs.
reconstructGtids2D ::
  SubExp ->
  (VName, VName) ->
  (VName, VName) ->
  (VName, VName) ->
  Builder GPU ()
reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y) = do
  bindGtid gtid_x gid_x ltid_x
  bindGtid gtid_y gid_y ltid_y
  where
    -- Bind one gtid from its group ID and local thread ID.
    bindGtid :: VName -> VName -> VName -> Builder GPU ()
    bindGtid gtid gid ltid =
      letBindNames [gtid]
        =<< toExp (le64 gid * pe64 tile_size + le64 ltid)

-- | Build a @tile_size@-by-@tile_size@ thread-level 'segMap2D'
-- (described as @"full_tile"@) in which each thread reads one element
-- of every tiled input array, and convert the results to input tiles
-- with 'inputsToTiles'.
--
-- For local thread IDs @(ltid_x, ltid_y)@ the candidate indices are
-- @i = tile_id * tile_size + ltid_x@ and
-- @j = tile_id * tile_size + ltid_y@; the permutation attached to
-- each tiled input selects which of the two is used.
--
-- With 'TileFull' the element is read unconditionally.  With
-- 'TilePartial' the read is guarded: the index must be smaller than
-- the outer size of the array, and the permutation-selected gtid
-- must be within its kernel dimension; otherwise a blank value of
-- the row type is produced.
readTile2D ::
  (SubExp, SubExp) ->
  (VName, VName) ->
  (VName, VName) ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  TileKind ->
  PrivStms ->
  SubExp ->
  [InputArray] ->
  Builder GPU [InputTile]
readTile2D (kdim_x, kdim_y) (gtid_x, gtid_y) (gid_x, gid_y) tile_size num_groups group_size kind privstms tile_id inputs =
  fmap (inputsToTiles inputs)
    . segMap2D "full_tile" lvl ResultNoSimplify (tile_size, tile_size)
    $ \(ltid_x, ltid_y) -> do
      i <- tileOffset "i" ltid_x
      j <- tileOffset "j" ltid_y

      reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y)
      addPrivStms [DimFix $ Var ltid_x, DimFix $ Var ltid_y] privstms

      let arrs_and_perms = tiledInputs inputs

          readTileElem (arr, perm) =
            -- No need for fullSlice because we are tiling only prims.
            letExp "tile_elem" . BasicOp . Index arr $
              Slice [DimFix $ permutedLast perm [i, j]]

          readTileElemIfInBounds (arr, perm) = do
            arr_t <- lookupType arr
            let tile_t = rowType arr_t
                w = arraySize 0 arr_t
                idx = permutedLast perm [i, j]
                -- Bounds check on the gtid selected by the same
                -- permutation as the index.
                othercheck =
                  permutedLast
                    perm
                    [ le64 gtid_y .<. pe64 kdim_y,
                      le64 gtid_x .<. pe64 kdim_x
                    ]
            eIf
              (toExp $ pe64 idx .<. pe64 w .&&. othercheck)
              (eBody [pure $ BasicOp $ Index arr $ Slice [DimFix idx]])
              (eBody [eBlank tile_t])

      fmap varsRes $
        case kind of
          TilePartial ->
            mapM (letExp "pre2d" <=< readTileElemIfInBounds) arrs_and_perms
          TileFull ->
            mapM readTileElem arrs_and_perms
  where
    lvl = SegThread num_groups group_size (SegNoVirtFull (SegSeqDims []))

    -- Index into the input for the given local thread ID:
    -- tile_id * tile_size + ltid.
    tileOffset :: String -> VName -> Builder GPU SubExp
    tileOffset desc ltid =
      letSubExp desc
        =<< toExp (pe64 tile_id * pe64 tile_size + le64 ltid)

    -- The last element of a list after rearranging it by 'perm'.
    -- Has an explicit polymorphic signature because it is used at
    -- more than one element type.
    permutedLast :: [Int] -> [a] -> a
    permutedLast perm = last . rearrangeShape perm

-- | The outer size of the first tiled input among @tiles@, found by
-- looking up the type of its tile array.  If none of the inputs is
-- tiled, the result is the 64-bit constant 0.
findTileSize :: HasScope rep m => [InputTile] -> m SubExp
findTileSize tiles =
  case mapMaybe tiledName tiles of
    tile : _ -> arraySize 0 <$> lookupType tile
    [] -> pure $ intConst Int64 0
  where
    -- The tile array of a tiled input, if any.
    tiledName = \case
      InputTiled _ tile -> Just tile
      InputUntiled {} -> Nothing

-- | Process one two-dimensional tile: a @tile_size@-by-@tile_size@
-- thread-level 'segMap2D' (described as @"acc"@, with private
-- results) in which each thread updates its accumulators.
--
-- Each thread reads its current accumulators at
-- @[ltid_x, ltid_y]@ and uses them as the neutral elements of a
-- redomap built from the reduction and map lambdas in @tile_args@.
-- The redomap runs over the thread's slices of the tiles, with width
-- given by 'findTileSize', but only when both gtids are within their
-- kernel dimensions; out-of-bounds threads return their accumulators
-- unchanged.
processTile2D ::
  (VName, VName) ->
  (VName, VName) ->
  (SubExp, SubExp) ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  ProcessTileArgs ->
  Builder GPU [VName]
processTile2D (gid_x, gid_y) (gtid_x, gtid_y) (kdim_x, kdim_y) tile_size num_groups group_size tile_args = do
  let privstms = processPrivStms tile_args
      red_comm = processComm tile_args
      red_lam = processRedLam tile_args
      map_lam = processMapLam tile_args
      tiles = processTiles tile_args
      accs = processAcc tile_args
      tile_id = processTileId tile_args
      lvl = SegThread num_groups group_size (SegNoVirtFull (SegSeqDims []))

  -- Might be truncated in case of a partial tile.
  actual_tile_size <- findTileSize tiles

  segMap2D "acc" lvl ResultPrivate (tile_size, tile_size) $
    \(ltid_x, ltid_y) -> do
      let ltid_slice = [DimFix $ Var ltid_x, DimFix $ Var ltid_y]

      reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y)

      addPrivStms ltid_slice privstms

      -- We replace the neutral elements with the accumulators (this is
      -- OK because the parallel semantics are not used after this
      -- point).
      thread_accs <- forM accs $ \acc ->
        letSubExp "acc" $ BasicOp $ Index acc $ Slice ltid_slice
      let form' = redomapSOAC [Reduce red_comm red_lam thread_accs] map_lam

          sliceTile (InputUntiled arr) =
            sliceUntiled arr tile_id tile_size actual_tile_size
          sliceTile (InputTiled perm tile) = do
            tile_t <- lookupType tile
            -- The first entry of the permutation picks both the
            -- local index to fix and the tile dimension to fix it
            -- in.  NOTE(review): 'head' assumes 'perm' is non-empty.
            let idx = DimFix $ Var $ head $ rearrangeShape perm [ltid_x, ltid_y]
            letExp "tile" $
              BasicOp $
                Index tile $
                  sliceAt tile_t (head perm) [idx]

      tiles' <- mapM sliceTile tiles

      fmap varsRes $
        letTupExp "acc"
          =<< eIf
            ( toExp $
                le64 gtid_x .<. pe64 kdim_x .&&. le64 gtid_y .<. pe64 kdim_y
            )
            (eBody [pure $ Op $ OtherOp $ Screma actual_tile_size tiles' form'])
            (resultBodyM thread_accs)

-- | Process the input elements left over after all whole tiles have
-- been handled.
--
-- The residual count is the input size from @args@ modulo
-- @tile_size@.  If it is zero, the accumulators from @args@ are
-- returned as they are.  Otherwise a tile is read with 'readTile2D'
-- in 'TilePartial' mode at tile index 'residualNumWholeTiles', every
-- tiled input is sliced to its first @residual_input@ elements in
-- both dimensions, and the result is processed with 'processTile2D'.
-- The body of that branch is passed through 'renameBody'.
processResidualTile2D ::
  (VName, VName) ->
  (VName, VName) ->
  (SubExp, SubExp) ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  ResidualTileArgs ->
  Builder GPU [VName]
processResidualTile2D gids gtids kdims tile_size num_groups group_size args = do
  -- The number of residual elements that are not covered by
  -- the whole tiles.
  residual_input <-
    letSubExp "residual_input" $
      BasicOp $
        BinOp (SRem Int64 Unsafe) w tile_size

  letTupExp "acc_after_residual"
    =<< eIf
      (toExp $ pe64 residual_input .==. 0)
      (resultBodyM $ map Var accs)
      (nonemptyTile residual_input)
  where
    privstms = residualPrivStms args
    red_comm = residualComm args
    red_lam = residualRedLam args
    map_lam = residualMapLam args
    accs = residualAcc args
    inputs = residualInput args
    num_whole_tiles = residualNumWholeTiles args
    w = residualInputSize args

    nonemptyTile residual_input = renameBody <=< runBodyBuilder $ do
      -- Collectively construct a tile.  Threads that are out-of-bounds
      -- provide a blank dummy value.
      full_tile <-
        readTile2D
          kdims
          gtids
          gids
          tile_size
          num_groups
          group_size
          TilePartial
          privstms
          num_whole_tiles
          inputs

      -- Keep only the first residual_input elements, with stride 1.
      let slice =
            DimSlice (intConst Int64 0) residual_input (intConst Int64 1)
      tiles <- forM full_tile $ \case
        InputTiled perm tile' ->
          InputTiled perm
            <$> letExp "partial_tile" (BasicOp $ Index tile' (Slice [slice, slice]))
        InputUntiled arr ->
          pure $ InputUntiled arr

      let tile_args =
            ProcessTileArgs privstms red_comm red_lam map_lam tiles accs num_whole_tiles

      -- Now each thread performs a traversal of the tile and
      -- updates its accumulator.
      resultBody . map Var
        <$> processTile2D
          gids
          gtids
          kdims
          tile_size
          num_groups
          group_size
          tile_args

tiling2d :: [(VName, SubExp)] -> DoTiling (VName, VName) (SubExp, SubExp)
tiling2d :: [(VName, SubExp)] -> DoTiling (VName, VName) (SubExp, SubExp)
tiling2d [(VName, SubExp)]
dims_on_top SegLevel
_initial_lvl (VName
gtid_x, VName
gtid_y) (SubExp
kdim_x, SubExp
kdim_y) SubExp
w = do
  VName
gid_x <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_x"
  VName
gid_y <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_y"

  Name
tile_size_key <- [Char] -> Name
nameFromString forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. Pretty a => a -> [Char]
prettyString forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"tile_size"
  SubExp
tile_size <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"tile_size" forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k) op. SizeOp -> HostOp rep op
SizeOp forall a b. (a -> b) -> a -> b
$ Name -> SizeClass -> SizeOp
GetSize Name
tile_size_key SizeClass
SizeTile
  SubExp
group_size <- forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"group_size" forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) SubExp
tile_size SubExp
tile_size

  SubExp
num_groups_x <-
    forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"num_groups_x" forall a b. (a -> b) -> a -> b
$
      forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
        BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SDivUp IntType
Int64 Safety
Unsafe) SubExp
kdim_x SubExp
tile_size
  SubExp
num_groups_y <-
    forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"num_groups_y" forall a b. (a -> b) -> a -> b
$
      forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
        BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SDivUp IntType
Int64 Safety
Unsafe) SubExp
kdim_y SubExp
tile_size

  SubExp
num_groups <-
    forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"num_groups_top"
      forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp
        (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef)
        SubExp
num_groups_x
        (SubExp
num_groups_y forall a. a -> [a] -> [a]
: forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> b
snd [(VName, SubExp)]
dims_on_top)

  VName
gid_flat <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"gid_flat"
  let lvl :: SegLevel
lvl = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegGroup (forall {k} (u :: k) e. e -> Count u e
Count SubExp
num_groups) (forall {k} (u :: k) e. e -> Count u e
Count SubExp
group_size) (SegSeqDims -> SegVirt
SegNoVirtFull ([Int] -> SegSeqDims
SegSeqDims []))
      space :: SegSpace
space =
        VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat forall a b. (a -> b) -> a -> b
$
          [(VName, SubExp)]
dims_on_top forall a. [a] -> [a] -> [a]
++ [(VName
gid_x, SubExp
num_groups_x), (VName
gid_y, SubExp
num_groups_y)]

  forall (f :: * -> *) a. Applicative f => a -> f a
pure
    Tiling
      { tilingSegMap :: [Char]
-> SegLevel
-> ResultManifest
-> (PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName]
tilingSegMap = \[Char]
desc SegLevel
lvl' ResultManifest
manifest PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result
f ->
          [Char]
-> SegLevel
-> ResultManifest
-> (SubExp, SubExp)
-> ((VName, VName) -> Builder GPU Result)
-> BuilderT GPU (State VNameSource) [VName]
segMap2D [Char]
desc SegLevel
lvl' ResultManifest
manifest (SubExp
tile_size, SubExp
tile_size) forall a b. (a -> b) -> a -> b
$ \(VName
ltid_x, VName
ltid_y) -> do
            SubExp
-> (VName, VName)
-> (VName, VName)
-> (VName, VName)
-> Builder GPU ()
reconstructGtids2D SubExp
tile_size (VName
gtid_x, VName
gtid_y) (VName
gid_x, VName
gid_y) (VName
ltid_x, VName
ltid_y)
            PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result
f
              ( forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped forall a b. (a -> b) -> a -> b
$
                  forall a. a -> TPrimExp Int64 a
le64 VName
gtid_x forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
kdim_x forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. forall a. a -> TPrimExp Int64 a
le64 VName
gtid_y forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
kdim_y
              )
              [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid_x, forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid_y],
        tilingReadTile :: TileKind
-> PrivStms -> SubExp -> [InputArray] -> Builder GPU [InputTile]
tilingReadTile = (SubExp, SubExp)
-> (VName, VName)
-> (VName, VName)
-> SubExp
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> TileKind
-> PrivStms
-> SubExp
-> [InputArray]
-> Builder GPU [InputTile]
readTile2D (SubExp
kdim_x, SubExp
kdim_y) (VName
gtid_x, VName
gtid_y) (VName
gid_x, VName
gid_y) SubExp
tile_size (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl),
        tilingProcessTile :: ProcessTileArgs -> BuilderT GPU (State VNameSource) [VName]
tilingProcessTile = (VName, VName)
-> (VName, VName)
-> (SubExp, SubExp)
-> SubExp
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> ProcessTileArgs
-> BuilderT GPU (State VNameSource) [VName]
processTile2D (VName
gid_x, VName
gid_y) (VName
gtid_x, VName
gtid_y) (SubExp
kdim_x, SubExp
kdim_y) SubExp
tile_size (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl),
        tilingProcessResidualTile :: ResidualTileArgs -> BuilderT GPU (State VNameSource) [VName]
tilingProcessResidualTile = (VName, VName)
-> (VName, VName)
-> (SubExp, SubExp)
-> SubExp
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> ResidualTileArgs
-> BuilderT GPU (State VNameSource) [VName]
processResidualTile2D (VName
gid_x, VName
gid_y) (VName
gtid_x, VName
gtid_y) (SubExp
kdim_x, SubExp
kdim_y) SubExp
tile_size (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl),
        tilingTileReturns :: VName -> Builder GPU KernelResult
tilingTileReturns = [(VName, SubExp)]
-> [(SubExp, SubExp)] -> VName -> Builder GPU KernelResult
tileReturns [(VName, SubExp)]
dims_on_top [(SubExp
kdim_x, SubExp
tile_size), (SubExp
kdim_y, SubExp
tile_size)],
        tilingTileShape :: Shape
tilingTileShape = forall d. [d] -> ShapeBase d
Shape [SubExp
tile_size, SubExp
tile_size],
        tilingNumWholeTiles :: Builder GPU SubExp
tilingNumWholeTiles =
          forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"num_whole_tiles" forall a b. (a -> b) -> a -> b
$
            forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
              BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SQuot IntType
Int64 Safety
Unsafe) SubExp
w SubExp
tile_size,
        tilingLevel :: SegLevel
tilingLevel = SegLevel
lvl,
        tilingSpace :: SegSpace
tilingSpace = SegSpace
space
      }