{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}

-- | Perform a restricted form of loop tiling within SegMaps.  We only
-- tile primitive types, to avoid excessive local memory use.
module Futhark.Optimise.TileLoops (tileLoops) where

import Control.Monad.Reader
import Control.Monad.State
import qualified Data.Map.Strict as M
import Data.Maybe (mapMaybe)
import qualified Data.Sequence as Seq
import qualified Futhark.Analysis.Alias as Alias
import Futhark.IR.Kernels
import Futhark.IR.Prop.Aliases (consumedInStm)
import Futhark.MonadFreshNames
import Futhark.Optimise.BlkRegTiling
import Futhark.Optimise.TileLoops.Shared
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
import Prelude hiding (quot)

-- | The pass definition.
--
-- The pass is registered under the name @"tile loops"@.  It applies
-- 'optimiseStms' to the statements handed over by
-- 'intraproceduralTransformation'.
tileLoops :: Pass Kernels Kernels
tileLoops =
  Pass "tile loops" "Tile stream loops inside kernels" $
    intraproceduralTransformation onStms
  where
    -- Run the tiling monad for one statement sequence: the scope is
    -- supplied through the reader layer, and the fresh-name source of
    -- the surrounding monad is threaded through the state layer.
    onStms scope stms =
      modifyNameSource $
        runState $
          runReaderT (optimiseStms stms) scope

-- | Optimise the statements of a body with 'optimiseStms'.  The body
-- result is passed through unchanged.
optimiseBody :: Body Kernels -> TileM (Body Kernels)
optimiseBody (Body () stms res) =
  Body () <$> optimiseStms stms <*> pure res

-- | Optimise each statement of a sequence with 'optimiseStm' and
-- concatenate the resulting statement sequences in order.  The names
-- bound by the whole input sequence are brought into scope while the
-- individual statements are processed.
optimiseStms :: Stms Kernels -> TileM (Stms Kernels)
optimiseStms stms =
  localScope (scopeOf stms) $
    mconcat <$> mapM optimiseStm (stmsToList stms)

-- | Optimise a single statement, producing the statements that
-- replace it.
--
-- For a 'SegMap' at 'SegThread' level, three strategies are attempted
-- in order, and the first one that succeeds wins:
--
--   1. 'doRegTiling3D',
--   2. 'mmBlkRegTiling',
--   3. 'tileInKernelBody' (which hands the kernel back unchanged when
--      it finds nothing to tile).
--
-- Any other statement is rebuilt with the same pattern and auxiliary
-- information, after applying 'optimiseBody' to each body nested
-- inside its expression.
optimiseStm :: Stm Kernels -> TileM (Stms Kernels)
optimiseStm stm@(Let pat aux (Op (SegOp (SegMap lvl@SegThread {} space ts kbody)))) = do
  res3dtiling <- doRegTiling3D stm
  case res3dtiling of
    -- The extra bindings are emitted before the rewritten statement.
    Just (extra_bnds, stmt') -> return (extra_bnds <> oneStm stmt')
    Nothing -> do
      blkRegTiling_res <- mmBlkRegTiling stm
      case blkRegTiling_res of
        Just (extra_bnds, stmt') -> return (extra_bnds <> oneStm stmt')
        Nothing -> do
          (host_stms, (lvl', space', kbody')) <-
            tileInKernelBody mempty initial_variance lvl space ts kbody
          -- Reassemble the SegMap with the (possibly) new level, space
          -- and body; the return types 'ts' are kept as they were.
          return $
            host_stms
              <> oneStm (Let pat aux $ Op $ SegOp $ SegMap lvl' space' ts kbody')
  where
    -- Every name bound by the kernel space starts out with an empty
    -- variance set.
    initial_variance = M.map mempty $ scopeOfSegSpace space
optimiseStm (Let pat aux e) =
  pure <$> (Let pat aux <$> mapExpM optimise e)
  where
    -- Only bodies are transformed; each one is optimised with the
    -- scope handed to 'mapOnBody' in effect.
    optimise =
      identityMapper
        { mapOnBody = \scope -> localScope scope . optimiseBody
        }

-- | Attempt to tile the body of a kernel.
--
-- Tiling is only attempted when every kernel result is a plain
-- 'Returns'.  The body is then handed to 'tileInBody', and on success
-- the result is a triple of:
--
--   * the host-level statements produced by the tiling,
--   * the level and space taken from the resulting 'Tiling'
--     ('tilingLevel' and 'tilingSpace'), and
--   * a new kernel body, built by running the tiled body and mapping
--     'tilingTileReturns' over what it produces.
--
-- In every other case the original level, space and body are returned
-- together with an empty sequence of host statements.
tileInKernelBody ::
  Names ->
  VarianceTable ->
  SegLevel ->
  SegSpace ->
  [Type] ->
  KernelBody Kernels ->
  TileM (Stms Kernels, (SegLevel, SegSpace, KernelBody Kernels))
tileInKernelBody branch_variant initial_variance lvl initial_kspace ts kbody
  | Just kbody_res <- mapM isSimpleResult $ kernelBodyResult kbody = do
    maybe_tiled <-
      tileInBody branch_variant initial_variance lvl initial_kspace ts $
        Body () (kernelBodyStms kbody) kbody_res
    case maybe_tiled of
      Just (host_stms, tiling, tiledBody) -> do
        -- The tiled body is run with empty arguments; whatever
        -- statements it emits become the new kernel body.
        (res', stms') <-
          runBinder $ mapM (tilingTileReturns tiling) =<< tiledBody mempty mempty
        return
          ( host_stms,
            ( tilingLevel tiling,
              tilingSpace tiling,
              KernelBody () stms' res'
            )
          )
      Nothing ->
        return (mempty, (lvl, initial_kspace, kbody))
  | otherwise =
    return (mempty, (lvl, initial_kspace, kbody))
  where
    -- Extract the returned operand of a 'Returns' result; every other
    -- kind of kernel result disqualifies the kernel from tiling.
    isSimpleResult :: KernelResult -> Maybe SubExp
    isSimpleResult (Returns _ se) = Just se
    isSimpleResult _ = Nothing

-- | Search a body for a statement that can be tiled.
--
-- The statements are inspected in order.  For each one, the following
-- is attempted, in this priority:
--
--   1. two-dimensional tiling ('tiling2d'), using the last two
--      dimensions of the kernel space;
--   2. one-dimensional tiling ('tiling1d'), using the last dimension of
--      the kernel space;
--   3. if the statement is a for-loop, recursively tiling inside the
--      loop body and wrapping the result with 'tileDoLoop'.
--
-- If none applies, the statement joins the prelude and the search
-- moves on to the next statement.  The result is 'Nothing' when the
-- end of the body is reached without anything having been tiled.
tileInBody ::
  Names ->
  VarianceTable ->
  SegLevel ->
  SegSpace ->
  [Type] ->
  Body Kernels ->
  TileM (Maybe (Stms Kernels, Tiling, TiledBody))
tileInBody branch_variant initial_variance initial_lvl initial_space res_ts (Body () initial_kstms stms_res) =
  descend mempty $ stmsToList initial_kstms
  where
    -- The incoming variance table, extended with every statement of
    -- this body.
    variance = varianceInStms initial_variance initial_kstms

    -- @descend prestms stms@: @prestms@ are the statements already
    -- passed over, @stms@ the ones still to be considered.
    descend ::
      Stms Kernels ->
      [Stm Kernels] ->
      TileM (Maybe (Stms Kernels, Tiling, TiledBody))
    descend _ [] =
      return Nothing
    descend prestms (stm_to_tile : poststms)
      -- 2D tiling of redomap.
      | (gtids, kdims) <- unzip $ unSegSpace initial_space,
        Just (w, arrs, form) <- tileable stm_to_tile,
        Just inputs <-
          mapM (invariantToOneOfTwoInnerDims branch_variant variance gtids) arrs,
        not $ null $ tiledInputs inputs,
        -- Requires at least two dimensions; the last one is bound to
        -- the @_y@ names and the second-to-last to the @_x@ names.
        gtid_y : gtid_x : top_gtids_rev <- reverse gtids,
        kdim_y : kdim_x : top_kdims_rev <- reverse kdims,
        (prestms', poststms') <-
          preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms),
        used <- freeIn stm_to_tile <> freeIn poststms' <> freeIn stms_res =
        Just . injectPrelude initial_space variance prestms' used
          <$> tileGeneric
            (tiling2d $ reverse $ zip top_gtids_rev top_kdims_rev)
            initial_lvl
            res_ts
            (stmPattern stm_to_tile)
            (gtid_x, gtid_y)
            (kdim_x, kdim_y)
            w
            form
            inputs
            poststms'
            stms_res
      -- 1D tiling of redomap.
      | (gtid, kdim) : top_space_rev <- reverse $ unSegSpace initial_space,
        Just (w, arrs, form) <- tileable stm_to_tile,
        inputs <- map (is1DTileable gtid variance) arrs,
        not $ null $ tiledInputs inputs,
        -- Not applicable when the tiled dimension is branch-variant.
        not $ gtid `nameIn` branch_variant,
        (prestms', poststms') <-
          preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms),
        used <- freeIn stm_to_tile <> freeIn poststms' <> freeIn stms_res =
        Just . injectPrelude initial_space variance prestms' used
          <$> tileGeneric
            (tiling1d $ reverse top_space_rev)
            initial_lvl
            res_ts
            (stmPattern stm_to_tile)
            gtid
            kdim
            w
            form
            inputs
            poststms'
            stms_res
      -- Tiling inside for-loop.  Only loops matching both @[]@
      -- patterns below are considered.
      | DoLoop [] merge (ForLoop i it bound []) loopbody <- stmExp stm_to_tile,
        (prestms', poststms') <-
          preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms) = do
        let -- Inside the loop, everything recorded in the variance
            -- table for the free variables of the loop bound also
            -- counts as branch-variant.
            branch_variant' =
              branch_variant
                <> foldMap
                  (flip (M.findWithDefault mempty) variance)
                  (namesToList (freeIn bound))
            merge_params = map fst merge

        -- Look for something to tile in the loop body, with the loop
        -- index and the merge parameters in scope.
        maybe_tiled <-
          localScope (M.insert i (IndexName it) $ scopeOfFParams merge_params) $
            tileInBody
              branch_variant'
              variance
              initial_lvl
              initial_space
              (map paramType merge_params)
              $ mkBody (bodyStms loopbody) (bodyResult loopbody)

        case maybe_tiled of
          -- Nothing tileable inside the loop: continue with the
          -- statements after it.
          Nothing -> next
          Just tiled ->
            Just
              <$> tileDoLoop
                initial_space
                variance
                prestms'
                (freeIn loopbody <> freeIn merge)
                tiled
                res_ts
                (stmPattern stm_to_tile)
                (stmAux stm_to_tile)
                merge
                i
                it
                bound
                poststms'
                stms_res
      | otherwise = next
      where
        -- Give up on this statement: append it to the prelude, bring
        -- its names into scope, and continue with the rest.
        next :: TileM (Maybe (Stms Kernels, Tiling, TiledBody))
        next =
          localScope (scopeOf stm_to_tile) $
            descend (prestms <> oneStm stm_to_tile) poststms

-- | Split the prelude around the statement that is about to be tiled.
--
-- A prelude statement is kept in the prelude when at least one of the
-- names it binds is either free in the tiled statement, or is listed
-- in the variance table entry of a name that is free in the tiled
-- statement.  Every other prelude statement is moved to the front of
-- the postlude.  Relative statement order is preserved in both
-- halves.
--
-- Returns @(kept prelude, moved statements <> postlude)@.
preludeToPostlude ::
  VarianceTable ->
  Stms Kernels ->
  Stm Kernels ->
  Stms Kernels ->
  (Stms Kernels, Stms Kernels)
preludeToPostlude variance prelude stm_to_tile postlude =
  (needed, unneeded <> postlude)
  where
    (needed, unneeded) = Seq.partition isNeeded prelude

    -- A statement is needed as soon as one of its bound names is.
    isNeeded = any (`nameIn` needed_names) . patternNames . stmPattern

    -- Names occurring free in the tiled statement...
    direct_names = freeIn stm_to_tile

    -- ...extended with the variance table entry of each such name
    -- (names without an entry contribute nothing).
    needed_names = direct_names <> foldMap varianceOf (namesToList direct_names)

    varianceOf v = M.findWithDefault mempty v variance

-- | Split the prelude statements that precede a tiled loop (or
-- something containing a tiled loop) into three groups, returned in
-- this order:
--
-- 1) Group-level statements that are invariant to the threads in the
-- group.
--
-- 2) Thread-variant statements that should be computed once with a
-- segmap_thread_scalar ("precomputed").
--
-- 3) Thread-variant statements that should be recomputed whenever
-- they are needed ("recomputed").
--
-- The third group duplicates computation, so it is kept as small as
-- possible.  Currently it is used for results that are views of an
-- array (slicing, rotate, etc) and that are used after the prelude,
-- because these cannot be efficiently represented by a scalar segmap
-- (they would be manifested in memory), together with every
-- statement that is variant to such a result.
--
-- The arguments are the variance table, the prelude statements, the
-- "private" names that statements in the first group must be
-- invariant to, and the names used after the prelude.
partitionPrelude ::
  VarianceTable ->
  Stms Kernels ->
  Names ->
  Names ->
  (Stms Kernels, Stms Kernels, Stms Kernels)
partitionPrelude variance prestms private used_after =
  (invariant_prestms, precomputed_variant_prestms, recomputed_variant_prestms)
  where
    -- First split: group-invariant versus thread-variant.
    (invariant_prestms, variant_prestms) =
      Seq.partition groupInvariant prestms

    -- Second split, of the thread-variant statements only.
    (recomputed_variant_prestms, precomputed_variant_prestms) =
      Seq.partition recompute variant_prestms

    -- Names bound by the pattern of a statement.
    boundBy stm = patternNames (stmPattern stm)

    -- Does the statement bind at least one of the given names?
    bindsAnyOf names stm = any (`nameIn` names) (boundBy stm)

    -- A statement is invariant to a set of names when none of them
    -- occur in the variance table entry of the first name it binds.
    -- Only the first bound name is consulted.
    invariantTo names stm =
      case boundBy stm of
        [] -> True -- Does not matter.
        v : _ ->
          all (not . (`nameIn` names)) $
            namesToList (M.findWithDefault mempty v variance)

    -- Everything consumed by some prelude statement, according to an
    -- alias analysis started from an empty alias table.
    (aliased_prestms, _) = Alias.analyseStms mempty prestms
    consumed_in_prestms = foldMap consumedInStm aliased_prestms

    -- All names bound by prelude statements of which at least one
    -- result is consumed within the prelude.
    later_consumed =
      namesFromList
        [ v
          | stm <- stmsToList prestms,
            bindsAnyOf consumed_in_prestms stm,
            v <- boundBy stm
        ]

    -- Group-level statements must be invariant to the private names,
    -- and may neither produce nor depend on something consumed.
    groupInvariant stm =
      invariantTo private stm
        && not (bindsAnyOf later_consumed stm)
        && invariantTo later_consumed stm

    -- Expressions treated as array views: an index whose slice still
    -- has slice dimensions, and any rotate, rearrange or reshape.
    isView (BasicOp op) =
      case op of
        Index _ slice -> not (null (sliceDims slice))
        Rotate {} -> True
        Rearrange {} -> True
        Reshape {} -> True
        _ -> False
    isView _ = False

    -- A view only forces recomputation if it is used after the
    -- prelude.
    mustBeInlined stm = isView (stmExp stm) && bindsAnyOf used_after stm

    must_be_inlined =
      namesFromList
        [ v
          | stm <- stmsToList variant_prestms,
            mustBeInlined stm,
            v <- boundBy stm
        ]

    -- Recompute the views themselves, and whatever is variant to
    -- them.
    recompute stm =
      bindsAnyOf must_be_inlined stm
        || not (invariantTo must_be_inlined stm)

-- | Wrap a tiled body so that the given prelude statements are
-- emitted before the original tiled body runs.  The host statements
-- and the tiling are passed through untouched.
--
-- The private names handed to the wrapped body are extended with the
-- dimensions of the initial space that are not part of the tiling
-- space.  Anything that is variant to the "private" names should be
-- considered thread-local.
--
-- The prelude is split with 'partitionPrelude': the invariant
-- statements are added directly, the precomputed variant statements
-- are handed to 'doPrelude', and the recomputed variant statements
-- are packaged as private statements that are put in front of the
-- ones the original body receives.
injectPrelude ::
  SegSpace ->
  VarianceTable ->
  Stms Kernels ->
  Names ->
  (Stms Kernels, Tiling, TiledBody) ->
  (Stms Kernels, Tiling, TiledBody)
injectPrelude initial_space variance prestms used (host_stms, tiling, tiledBody) =
  (host_stms, tiling, withPrelude)
  where
    -- Dimensions that the tiling itself takes care of.
    tiled_dims = unSegSpace (tilingSpace tiling)

    -- Names of the dimensions of the initial space that are not tiled.
    nontiled_names =
      [v | dim@(v, _) <- unSegSpace initial_space, dim `notElem` tiled_dims]

    withPrelude private privstms = do
      let private' = private <> namesFromList nontiled_names
          (invariant, precomputed, recomputed) =
            partitionPrelude variance prestms private' used

      addStms invariant

      -- Liveness is computed relative to everything the rest of the
      -- body needs: the used names and whatever the recomputed
      -- statements mention.
      let live_set =
            namesToList (liveSet precomputed (used <> freeIn recomputed))

      prelude_arrs <-
        inScopeOf precomputed $
          doPrelude tiling privstms precomputed live_set

      -- Private statements for the recomputed part of the prelude,
      -- with a reader built from the prelude arrays and the live set.
      let prelude_privstms =
            PrivStms recomputed (mkReadPreludeValues prelude_arrs live_set)

      tiledBody private' (prelude_privstms <> privstms)

tileDoLoop ::
  SegSpace ->
  VarianceTable ->
  Stms Kernels ->
  Names ->
  (Stms Kernels, Tiling, TiledBody) ->
  [Type] ->
  Pattern Kernels ->
  StmAux (ExpDec Kernels) ->
  [(FParam Kernels, SubExp)] ->
  VName ->
  IntType ->
  SubExp ->
  Stms Kernels ->
  Result ->
  TileM (Stms Kernels, Tiling, TiledBody)
tileDoLoop :: SegSpace
-> AliasTable
-> Stms Kernels
-> Names
-> (Stms Kernels, Tiling, TiledBody)
-> [Type]
-> Pattern Kernels
-> StmAux (ExpDec Kernels)
-> [(FParam Kernels, SubExp)]
-> VName
-> IntType
-> SubExp
-> Stms Kernels
-> Result
-> ReaderT
     (Scope Kernels)
     (State VNameSource)
     (Stms Kernels, Tiling, TiledBody)
tileDoLoop SegSpace
initial_space AliasTable
variance Stms Kernels
prestms Names
used_in_body (Stms Kernels
host_stms, Tiling
tiling, TiledBody
tiledBody) [Type]
res_ts Pattern Kernels
pat StmAux (ExpDec Kernels)
aux [(FParam Kernels, SubExp)]
merge VName
i IntType
it SubExp
bound Stms Kernels
poststms Result
poststms_res = do
  let prestms_used :: Names
prestms_used = Names
used_in_body Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Stms Kernels -> Names
forall a. FreeIn a => a -> Names
freeIn Stms Kernels
poststms Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Result -> Names
forall a. FreeIn a => a -> Names
freeIn Result
poststms_res
      ( Stms Kernels
invariant_prestms,
        Stms Kernels
precomputed_variant_prestms,
        Stms Kernels
recomputed_variant_prestms
        ) =
          AliasTable
-> Stms Kernels
-> Names
-> Names
-> (Stms Kernels, Stms Kernels, Stms Kernels)
partitionPrelude AliasTable
variance Stms Kernels
prestms Names
tiled_kdims Names
prestms_used

  let ([Param (TypeBase Shape Uniqueness)]
mergeparams, Result
mergeinits) = [(Param (TypeBase Shape Uniqueness), SubExp)]
-> ([Param (TypeBase Shape Uniqueness)], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param (TypeBase Shape Uniqueness), SubExp)]
[(FParam Kernels, SubExp)]
merge

      -- Expand the loop merge parameters to be arrays.
      tileDim :: TypeBase Shape Uniqueness -> TypeBase Shape Uniqueness
tileDim TypeBase Shape Uniqueness
t = TypeBase Shape Uniqueness
-> Shape -> Uniqueness -> TypeBase Shape Uniqueness
forall shape u_unused u.
ArrayShape shape =>
TypeBase shape u_unused -> shape -> u -> TypeBase shape u
arrayOf TypeBase Shape Uniqueness
t (Tiling -> Shape
tilingTileShape Tiling
tiling) (Uniqueness -> TypeBase Shape Uniqueness)
-> Uniqueness -> TypeBase Shape Uniqueness
forall a b. (a -> b) -> a -> b
$ TypeBase Shape Uniqueness -> Uniqueness
forall shape. TypeBase shape Uniqueness -> Uniqueness
uniqueness TypeBase Shape Uniqueness
t

      merge_scope :: Scope Kernels
merge_scope = VName -> NameInfo Kernels -> Scope Kernels -> Scope Kernels
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
i (IntType -> NameInfo Kernels
forall lore. IntType -> NameInfo lore
IndexName IntType
it) (Scope Kernels -> Scope Kernels) -> Scope Kernels -> Scope Kernels
forall a b. (a -> b) -> a -> b
$ [Param (TypeBase Shape Uniqueness)] -> Scope Kernels
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams [Param (TypeBase Shape Uniqueness)]
mergeparams

      tiledBody' :: TiledBody
tiledBody' Names
private PrivStms
privstms = Scope Kernels
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) [VName]
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (Stms Kernels -> Scope Kernels
forall lore a. Scoped lore a => a -> Scope lore
scopeOf Stms Kernels
host_stms Scope Kernels -> Scope Kernels -> Scope Kernels
forall a. Semigroup a => a -> a -> a
<> Scope Kernels
merge_scope) (BinderT Kernels (State VNameSource) [VName]
 -> BinderT Kernels (State VNameSource) [VName])
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ do
        Stms (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms Stms (Lore (BinderT Kernels (State VNameSource)))
Stms Kernels
invariant_prestms

        let live_set :: [VName]
live_set =
              Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$
                Stms Kernels -> Names -> Names
forall a. FreeIn a => Stms Kernels -> a -> Names
liveSet Stms Kernels
precomputed_variant_prestms (Names -> Names) -> Names -> Names
forall a b. (a -> b) -> a -> b
$
                  Stms Kernels -> Names
forall a. FreeIn a => a -> Names
freeIn Stms Kernels
recomputed_variant_prestms Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
prestms_used

        [VName]
prelude_arrs <-
          Stms Kernels
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) [VName]
forall lore a (m :: * -> *) b.
(Scoped lore a, LocalScope lore m) =>
a -> m b -> m b
inScopeOf Stms Kernels
precomputed_variant_prestms (BinderT Kernels (State VNameSource) [VName]
 -> BinderT Kernels (State VNameSource) [VName])
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
            Tiling
-> PrivStms
-> Stms Kernels
-> [VName]
-> BinderT Kernels (State VNameSource) [VName]
doPrelude Tiling
tiling PrivStms
privstms Stms Kernels
precomputed_variant_prestms [VName]
live_set

        [Param (TypeBase Shape Uniqueness)]
mergeparams' <- [Param (TypeBase Shape Uniqueness)]
-> (Param (TypeBase Shape Uniqueness)
    -> BinderT
         Kernels (State VNameSource) (Param (TypeBase Shape Uniqueness)))
-> BinderT
     Kernels (State VNameSource) [Param (TypeBase Shape Uniqueness)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Param (TypeBase Shape Uniqueness)]
mergeparams ((Param (TypeBase Shape Uniqueness)
  -> BinderT
       Kernels (State VNameSource) (Param (TypeBase Shape Uniqueness)))
 -> BinderT
      Kernels (State VNameSource) [Param (TypeBase Shape Uniqueness)])
-> (Param (TypeBase Shape Uniqueness)
    -> BinderT
         Kernels (State VNameSource) (Param (TypeBase Shape Uniqueness)))
-> BinderT
     Kernels (State VNameSource) [Param (TypeBase Shape Uniqueness)]
forall a b. (a -> b) -> a -> b
$ \(Param VName
pname TypeBase Shape Uniqueness
pt) ->
          VName
-> TypeBase Shape Uniqueness -> Param (TypeBase Shape Uniqueness)
forall dec. VName -> dec -> Param dec
Param (VName
 -> TypeBase Shape Uniqueness -> Param (TypeBase Shape Uniqueness))
-> BinderT Kernels (State VNameSource) VName
-> BinderT
     Kernels
     (State VNameSource)
     (TypeBase Shape Uniqueness -> Param (TypeBase Shape Uniqueness))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (VName -> String
baseString VName
pname String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_group") BinderT
  Kernels
  (State VNameSource)
  (TypeBase Shape Uniqueness -> Param (TypeBase Shape Uniqueness))
-> BinderT Kernels (State VNameSource) (TypeBase Shape Uniqueness)
-> BinderT
     Kernels (State VNameSource) (Param (TypeBase Shape Uniqueness))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> TypeBase Shape Uniqueness
-> BinderT Kernels (State VNameSource) (TypeBase Shape Uniqueness)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TypeBase Shape Uniqueness -> TypeBase Shape Uniqueness
tileDim TypeBase Shape Uniqueness
pt)

        let merge_ts :: [Type]
merge_ts = (Param (TypeBase Shape Uniqueness) -> Type)
-> [Param (TypeBase Shape Uniqueness)] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape Uniqueness) -> Type
forall dec. Typed dec => Param dec -> Type
paramType [Param (TypeBase Shape Uniqueness)]
mergeparams

        let inloop_privstms :: PrivStms
inloop_privstms =
              Stms Kernels -> ReadPrelude -> PrivStms
PrivStms Stms Kernels
recomputed_variant_prestms (ReadPrelude -> PrivStms) -> ReadPrelude -> PrivStms
forall a b. (a -> b) -> a -> b
$
                [VName] -> [VName] -> ReadPrelude
mkReadPreludeValues [VName]
prelude_arrs [VName]
live_set

        Result
mergeinit' <-
          ([VName] -> Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var) (BinderT Kernels (State VNameSource) [VName]
 -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
            Certificates
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) [VName]
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec Kernels)
aux) (BinderT Kernels (State VNameSource) [VName]
 -> BinderT Kernels (State VNameSource) [VName])
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
              Tiling
-> String
-> SegLevel
-> ResultManifest
-> (PrimExp VName
    -> Slice SubExp -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
tilingSegMap Tiling
tiling String
"tiled_loopinit" (Tiling -> SegLevel
scalarLevel Tiling
tiling) ResultManifest
ResultPrivate ((PrimExp VName
  -> Slice SubExp -> BinderT Kernels (State VNameSource) Result)
 -> BinderT Kernels (State VNameSource) [VName])
-> (PrimExp VName
    -> Slice SubExp -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
                \PrimExp VName
in_bounds Slice SubExp
slice ->
                  ([VName] -> Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var) (BinderT Kernels (State VNameSource) [VName]
 -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
                    String
-> PrimExp VName
-> [Type]
-> BinderT Kernels (State VNameSource) Result
-> BinderT Kernels (State VNameSource) [VName]
protectOutOfBounds String
"loopinit" PrimExp VName
in_bounds [Type]
merge_ts (BinderT Kernels (State VNameSource) Result
 -> BinderT Kernels (State VNameSource) [VName])
-> BinderT Kernels (State VNameSource) Result
-> BinderT Kernels (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ do
                      Slice SubExp -> PrivStms -> BinderT Kernels (State VNameSource) ()
addPrivStms Slice SubExp
slice PrivStms
inloop_privstms
                      Slice SubExp -> PrivStms -> BinderT Kernels (State VNameSource) ()
addPrivStms Slice SubExp
slice PrivStms
privstms
                      Result -> BinderT Kernels (State VNameSource) Result
forall (m :: * -> *) a. Monad m => a -> m a
return Result
mergeinits

        let merge' :: [(Param (TypeBase Shape Uniqueness), SubExp)]
merge' = [Param (TypeBase Shape Uniqueness)]
-> Result -> [(Param (TypeBase Shape Uniqueness), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase Shape Uniqueness)]
mergeparams' Result
mergeinit'

        let indexMergeParams :: ReadPrelude
indexMergeParams Slice SubExp
slice =
              Scope Kernels
-> BinderT Kernels (State VNameSource) ()
-> BinderT Kernels (State VNameSource) ()
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param (TypeBase Shape Uniqueness)] -> Scope Kernels
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams [Param (TypeBase Shape Uniqueness)]
mergeparams') (BinderT Kernels (State VNameSource) ()
 -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) ()
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
                [(Param (TypeBase Shape Uniqueness),
  Param (TypeBase Shape Uniqueness))]
-> ((Param (TypeBase Shape Uniqueness),
     Param (TypeBase Shape Uniqueness))
    -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (TypeBase Shape Uniqueness)]
-> [Param (TypeBase Shape Uniqueness)]
-> [(Param (TypeBase Shape Uniqueness),
     Param (TypeBase Shape Uniqueness))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase Shape Uniqueness)]
mergeparams [Param (TypeBase Shape Uniqueness)]
mergeparams') (((Param (TypeBase Shape Uniqueness),
   Param (TypeBase Shape Uniqueness))
  -> BinderT Kernels (State VNameSource) ())
 -> BinderT Kernels (State VNameSource) ())
-> ((Param (TypeBase Shape Uniqueness),
     Param (TypeBase Shape Uniqueness))
    -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param (TypeBase Shape Uniqueness)
to, Param (TypeBase Shape Uniqueness)
from) ->
                  [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
to] (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) ())
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
                    BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$
                      VName -> Slice SubExp -> BasicOp
Index (Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
from) (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
                        Type -> Slice SubExp -> Slice SubExp
fullSlice (Param (TypeBase Shape Uniqueness) -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param (TypeBase Shape Uniqueness)
from) Slice SubExp
slice

            private' :: Names
private' =
              Names
private Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [VName] -> Names
namesFromList ((Param (TypeBase Shape Uniqueness) -> VName)
-> [Param (TypeBase Shape Uniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase Shape Uniqueness)]
mergeparams [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ (Param (TypeBase Shape Uniqueness) -> VName)
-> [Param (TypeBase Shape Uniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase Shape Uniqueness)]
mergeparams')

            privstms' :: PrivStms
privstms' =
              Stms Kernels -> ReadPrelude -> PrivStms
PrivStms Stms Kernels
forall a. Monoid a => a
mempty ReadPrelude
indexMergeParams PrivStms -> PrivStms -> PrivStms
forall a. Semigroup a => a -> a -> a
<> PrivStms
privstms PrivStms -> PrivStms -> PrivStms
forall a. Semigroup a => a -> a -> a
<> PrivStms
inloop_privstms

        BodyT Kernels
loopbody' <-
          Binder Kernels (BodyT Kernels) -> Binder Kernels (BodyT Kernels)
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder Kernels (BodyT Kernels) -> Binder Kernels (BodyT Kernels))
-> Binder Kernels (BodyT Kernels) -> Binder Kernels (BodyT Kernels)
forall a b. (a -> b) -> a -> b
$
            Result -> BodyT Kernels
forall lore. Bindable lore => Result -> Body lore
resultBody (Result -> BodyT Kernels)
-> ([VName] -> Result) -> [VName] -> BodyT Kernels
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var
              ([VName] -> BodyT Kernels)
-> BinderT Kernels (State VNameSource) [VName]
-> Binder Kernels (BodyT Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TiledBody
tiledBody Names
private' PrivStms
privstms'
        [VName]
accs' <-
          String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) [VName]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [VName]
letTupExp String
"tiled_inside_loop" (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) [VName])
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
            [(FParam Kernels, SubExp)]
-> [(FParam Kernels, SubExp)]
-> LoopForm Kernels
-> BodyT Kernels
-> ExpT Kernels
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [] [(Param (TypeBase Shape Uniqueness), SubExp)]
[(FParam Kernels, SubExp)]
merge' (VName
-> IntType
-> SubExp
-> [(LParam Kernels, VName)]
-> LoopForm Kernels
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
i IntType
it SubExp
bound []) BodyT Kernels
loopbody'

        Tiling
-> PrivStms
-> Pattern Kernels
-> [VName]
-> Stms Kernels
-> Result
-> [Type]
-> BinderT Kernels (State VNameSource) [VName]
postludeGeneric Tiling
tiling (PrivStms
privstms PrivStms -> PrivStms -> PrivStms
forall a. Semigroup a => a -> a -> a
<> PrivStms
inloop_privstms) Pattern Kernels
pat [VName]
accs' Stms Kernels
poststms Result
poststms_res [Type]
res_ts

  (Stms Kernels, Tiling, TiledBody)
-> ReaderT
     (Scope Kernels)
     (State VNameSource)
     (Stms Kernels, Tiling, TiledBody)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms Kernels
host_stms, Tiling
tiling, TiledBody
tiledBody')
  where
    tiled_kdims :: Names
tiled_kdims =
      [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$
        ((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst ([(VName, SubExp)] -> [VName]) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$
          ((VName, SubExp) -> Bool) -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName, SubExp) -> [(VName, SubExp)] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` SegSpace -> [(VName, SubExp)]
unSegSpace (Tiling -> SegSpace
tilingSpace Tiling
tiling)) ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$
            SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
initial_space

-- | Build the prelude of a tiled kernel.
--
-- The prelude statements @prestms@ are wrapped in a SegMap produced by
-- 'tilingSegMap' at the tiling's 'scalarLevel', with results marked
-- 'ResultPrivate'.  When the in-bounds condition holds, the private
-- statements and the prelude statements are inserted and the
-- variables in @prestms_live@ are returned; otherwise blank values
-- ('eBlank') of the same types are returned instead.
--
-- Returns the names bound to the results of the SegMap, one per
-- element of @prestms_live@.
doPrelude :: Tiling -> PrivStms -> Stms Kernels -> [VName] -> Binder Kernels [VName]
doPrelude tiling privstms prestms prestms_live =
  -- Create a SegMap that takes care of the prelude for every thread.
  tilingSegMap tiling "prelude" (scalarLevel tiling) ResultPrivate $
    \in_bounds slice -> do
      -- The types are needed to construct the blank out-of-bounds results.
      ts <- mapM lookupType prestms_live
      fmap (map Var) $
        letTupExp "pre"
          =<< eIf
            (toExp in_bounds)
            ( do
                addPrivStms slice privstms
                addStms prestms
                resultBodyM $ map Var prestms_live
            )
            (eBody $ map eBlank ts)

-- | The names bound by the patterns of @stms@ that also occur free in
-- @after@.
liveSet :: FreeIn a => Stms Kernels -> a -> Names
liveSet stms after =
  namesFromList (concatMap (patternNames . stmPattern) stms)
    `namesIntersection` freeIn after

-- | Determine whether a statement is a candidate for tiling.
--
-- A statement qualifies when all of the following hold:
--
-- * its expression is a 'Screma' that 'isRedomapSOAC' recognises;
--
-- * the return type of the map lambda equals that of the (single,
--   combined via 'singleReduce') reduction lambda, i.e. there are no
--   mapout arrays;
--
-- * it has at least one input array;
--
-- * all return types and all parameters of the map lambda are of
--   primitive type.
--
-- On success, returns the @w@ operand of the 'Screma', its input
-- arrays, and a tuple of the reduction's commutativity, reduction
-- lambda, neutral elements, and the map lambda.  Otherwise 'Nothing'.
tileable ::
  Stm Kernels ->
  Maybe
    ( SubExp,
      [VName],
      (Commutativity, Lambda Kernels, [SubExp], Lambda Kernels)
    )
tileable stm
  | Op (OtherOp (Screma w arrs form)) <- stmExp stm,
    Just (reds, map_lam) <- isRedomapSOAC form,
    Reduce red_comm red_lam red_nes <- singleReduce reds,
    lambdaReturnType map_lam == lambdaReturnType red_lam, -- No mapout arrays.
    not $ null arrs,
    all primType $ lambdaReturnType map_lam,
    all (primType . paramType) $ lambdaParams map_lam =
    Just (w, arrs, (red_comm, red_lam, red_nes, map_lam))
  | otherwise =
    Nothing

-- | We classify the inputs to the tiled loop as whether they are
-- tileable (and with what permutation of the kernel indexes) or not.
-- In practice, we should have at least one tileable array per loop,
-- but this is not enforced in our representation.
--
-- 'InputTile' holds the permutation and the name of a tileable array;
-- 'InputDontTile' holds the name of an array that is not tiled.
data InputArray
  = InputTile [Int] VName
  | InputDontTile VName

-- | Extract the tileable inputs, in order, as pairs of array name and
-- permutation.  'InputDontTile' entries are dropped.
tiledInputs :: [InputArray] -> [(VName, [Int])]
tiledInputs = mapMaybe f
  where
    f (InputTile perm arr) = Just (arr, perm)
    f InputDontTile {} = Nothing

-- | A tile (or an original untiled array).
--
-- As constructed by 'inputsToTiles': 'InputTiled' carries the
-- permutation of the corresponding 'InputTile' input together with the
-- name of the tile, while 'InputUntiled' carries the name of the
-- original array of an 'InputDontTile' input.
data InputTile
  = InputTiled [Int] VName
  | InputUntiled VName

-- | Convert inputs to tiles, given the names of the tiles that were
-- read for the tileable inputs.
--
-- The inputs are traversed in order.  Each 'InputTile' consumes the
-- next name from the tile list and becomes an 'InputTiled' with the
-- same permutation; each 'InputDontTile' consumes no tile name and
-- becomes an 'InputUntiled' of the original array.  Conversion stops
-- (truncating the result) at the end of the inputs, or as soon as an
-- 'InputTile' is encountered with no tile names left.
inputsToTiles :: [InputArray] -> [VName] -> [InputTile]
inputsToTiles (InputTile perm _ : inputs) (tile : tiles) =
  InputTiled perm tile : inputsToTiles inputs tiles
inputsToTiles (InputDontTile arr : inputs) tiles =
  InputUntiled arr : inputsToTiles inputs tiles
inputsToTiles _ _ = []

-- | Slice out the part of an untiled array that corresponds to a
-- given tile.
--
-- The actual tile size may be smaller for the last tile, so the
-- length of the slice (@this_tile_size@) is passed separately from
-- the size used to compute the offset (@full_tile_size@).
--
-- The slice starts at @tile_id * full_tile_size@ (computed as a
-- 64-bit expression), has @this_tile_size@ elements and stride 1;
-- 'fullSlice' completes the slice for the array's type.  Returns the
-- name bound to the sliced array.
sliceUntiled ::
  MonadBinder m =>
  VName ->
  SubExp ->
  SubExp ->
  SubExp ->
  m VName
sliceUntiled arr tile_id full_tile_size this_tile_size = do
  arr_t <- lookupType arr
  slice_offset <-
    letSubExp "slice_offset" =<< toExp (pe64 tile_id * pe64 full_tile_size)
  let slice = DimSlice slice_offset this_tile_size (intConst Int64 1)
  letExp "untiled_slice" $
    BasicOp $ Index arr $ fullSlice arr_t [slice]

-- | Statements that we insert directly into every thread-private
-- SegMaps.  This is for things that cannot efficiently be computed
-- once in advance in the prelude SegMap, primarily (exclusively?)
-- array slicing operations.
--
-- The 'ReadPrelude' component is an action that 'addPrivStms' runs on
-- the local slice before the statements themselves are added.
data PrivStms = PrivStms (Stms Kernels) ReadPrelude

-- | Construct a 'PrivStms' from statements alone; the 'ReadPrelude'
-- component ignores its slice and does nothing.
privStms :: Stms Kernels -> PrivStms
privStms stms = PrivStms stms $ const $ return ()

-- | Insert private statements into the statements currently being
-- built: first run the 'ReadPrelude' action on the given local slice,
-- then add the statements themselves.
addPrivStms :: Slice SubExp -> PrivStms -> Binder Kernels ()
addPrivStms local_slice (PrivStms stms readPrelude) = do
  readPrelude local_slice
  addStms stms

-- | Combining two 'PrivStms' concatenates their statements (left
-- operand first) and sequences their 'ReadPrelude' actions (left
-- operand first) on the same slice.
instance Semigroup PrivStms where
  PrivStms stms_x readPrelude_x <> PrivStms stms_y readPrelude_y =
    PrivStms stms_z readPrelude_z
    where
      stms_z = stms_x <> stms_y
      readPrelude_z slice = readPrelude_x slice >> readPrelude_y slice

-- | The identity is no statements and (via 'privStms') a
-- 'ReadPrelude' that does nothing.
instance Monoid PrivStms where
  mempty = privStms mempty

-- | An action in the 'Binder' monad parameterised over a slice.
-- 'addPrivStms' invokes it with the local slice before adding the
-- private statements.  NOTE(review): judging by the name, presumably
-- used to read back values computed in the prelude -- confirm against
-- the construction sites.
type ReadPrelude = Slice SubExp -> Binder Kernels ()

-- | The arguments passed to 'tilingProcessTile'.
data ProcessTileArgs = ProcessTileArgs
  { processPrivStms :: PrivStms,
    processComm :: Commutativity,
    processRedLam :: Lambda Kernels,
    processMapLam :: Lambda Kernels,
    processTiles :: [InputTile],
    processAcc :: [VName],
    processTileId :: SubExp
  }

-- | The arguments passed to 'tilingProcessResidualTile'.
data ResidualTileArgs = ResidualTileArgs
  { residualPrivStms :: PrivStms,
    residualComm :: Commutativity,
    residualRedLam :: Lambda Kernels,
    residualMapLam :: Lambda Kernels,
    residualInput :: [InputArray],
    residualAcc :: [VName],
    residualInputSize :: SubExp,
    residualNumWholeTiles :: SubExp
  }

-- | Information about a loop that has been tiled inside a kernel, as
-- well as the kinds of changes that we would then like to perform on
-- the kernel.
data Tiling = Tiling
  { tilingSegMap ::
      String ->
      SegLevel ->
      ResultManifest ->
      (PrimExp VName -> Slice SubExp -> Binder Kernels [SubExp]) ->
      Binder Kernels [VName],
    -- The boolean PrimExp indicates whether they are in-bounds.

    tilingReadTile ::
      TileKind ->
      PrivStms ->
      SubExp ->
      [InputArray] ->
      Binder Kernels [InputTile],
    tilingProcessTile ::
      ProcessTileArgs ->
      Binder Kernels [VName],
    tilingProcessResidualTile ::
      ResidualTileArgs ->
      Binder Kernels [VName],
    tilingTileReturns :: VName -> Binder Kernels KernelResult,
    tilingSpace :: SegSpace,
    tilingTileShape :: Shape,
    tilingLevel :: SegLevel,
    tilingNumWholeTiles :: Binder Kernels SubExp
  }

-- | The type of a function that constructs a 'Tiling' in the 'Binder'
-- monad.  It is applied to a 'SegLevel', a @gtids@ value, a @kdims@
-- value and a 'SubExp' (see its use in 'tileGeneric'); the concrete
-- representation of @gtids@ and @kdims@ is chosen by each
-- instantiation.
type DoTiling gtids kdims =
  SegLevel -> gtids -> kdims -> SubExp -> Binder Kernels Tiling

-- | A 'SegThread' level with the same number of groups and group size
-- as the level of the given tiling, and with 'SegNoVirt'.
scalarLevel :: Tiling -> SegLevel
scalarLevel tiling =
  SegThread (segNumGroups lvl) (segGroupSize lvl) SegNoVirt
  where
    lvl = tilingLevel tiling

-- | Guard a computation by an in-bounds condition.
--
-- Builds an @if@ on @in_bounds@: the true branch is a body whose
-- statements and result come from running @m@; the false branch
-- returns blank values ('eBlank') of the types @ts@.  The results of
-- the @if@ are bound using @desc@ as the name hint, and the bound
-- names are returned.
protectOutOfBounds ::
  String ->
  PrimExp VName ->
  [Type] ->
  Binder Kernels [SubExp] ->
  Binder Kernels [VName]
protectOutOfBounds desc in_bounds ts m =
  letTupExp desc =<< eIf (toExp in_bounds) (resultBody <$> m) (eBody $ map eBlank ts)

-- | Build the postlude of a tiled kernel.
--
-- A SegMap is produced by 'tilingSegMap' at the tiling's
-- 'scalarLevel', with results marked 'ResultPrivate'.  In it, each
-- name of @pat@ is bound to the element at the given slice of the
-- corresponding array in @accs'@.  Then:
--
-- * if @poststms@ is empty, the private statements are added and
--   @poststms_res@ is returned directly;
--
-- * otherwise the private statements and @poststms@ are inserted
--   under 'protectOutOfBounds', so that out-of-bounds evaluation
--   yields blank values of the types @res_ts@.
--
-- Returns the names bound to the results of the SegMap.
postludeGeneric ::
  Tiling ->
  PrivStms ->
  Pattern Kernels ->
  [VName] ->
  Stms Kernels ->
  Result ->
  [Type] ->
  Binder Kernels [VName]
postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts =
  tilingSegMap tiling "thread_res" (scalarLevel tiling) ResultPrivate $ \in_bounds slice -> do
    -- Read our per-thread result from the tiled loop.
    forM_ (zip (patternNames pat) accs') $ \(us, everyone) -> do
      everyone_t <- lookupType everyone
      letBindNames [us] $ BasicOp $ Index everyone $ fullSlice everyone_t slice

    if poststms == mempty
      then do
        -- The privstms may still be necessary for the result.
        addPrivStms slice privstms
        return poststms_res
      else
        fmap (map Var) $
          protectOutOfBounds "postlude" in_bounds res_ts $ do
            addPrivStms slice privstms
            addStms poststms
            return poststms_res

-- | The deferred body of a tiled kernel: given a set of names and the
-- private statements, it emits the tiled computation into the
-- 'Binder' and yields the names of the results.  'tileGeneric'
-- returns one of these as the third component of its result.
type TiledBody = Names -> PrivStms -> Binder Kernels [VName]

-- | Shared driver for tiling a map-reduce over the innermost
-- dimension of size @w@.
--
-- The 'DoTiling' argument is run (with the initial level, the thread
-- indices, the kernel dimensions and @w@) to construct the 'Tiling'
-- strategy.  The result is a triple of
--
--   * the statements needed to set up the tiling,
--
--   * the 'Tiling' itself, and
--
--   * a 'TiledBody' that, when invoked, emits the tiled loop, the
--     residual-tile handling and the postlude.
--
-- The @form@ tuple holds the commutativity of the reduction, the
-- reduction lambda, its neutral elements, and the map lambda.  The
-- pattern, @poststms@, @poststms_res@ and @res_ts@ are handed
-- unchanged to 'postludeGeneric'.
tileGeneric ::
  DoTiling gtids kdims ->
  SegLevel ->
  [Type] ->
  Pattern Kernels ->
  gtids ->
  kdims ->
  SubExp ->
  (Commutativity, Lambda Kernels, [SubExp], Lambda Kernels) ->
  [InputArray] ->
  Stms Kernels ->
  Result ->
  TileM (Stms Kernels, Tiling, TiledBody)
tileGeneric doTiling initial_lvl res_ts pat gtids kdims w form inputs poststms poststms_res = do
  (tiling, tiling_stms) <- runBinder $ doTiling initial_lvl gtids kdims w

  return (tiling_stms, tiling, tiledBody tiling)
  where
    (red_comm, red_lam, red_nes, map_lam) = form

    tiledBody :: Tiling -> Names -> PrivStms -> Binder Kernels [VName]
    tiledBody tiling _private privstms = do
      let tile_shape = tilingTileShape tiling

      num_whole_tiles <- tilingNumWholeTiles tiling

      -- We don't use a Replicate here, because we want to enforce a
      -- scalar memory space.
      mergeinits <- tilingSegMap tiling "mergeinit" (scalarLevel tiling) ResultPrivate $ \in_bounds slice ->
        -- Constant neutral elements (a common case) do not need protection from OOB.
        if freeIn red_nes == mempty
          then return red_nes
          else
            fmap (map Var) $
              protectOutOfBounds "neutral" in_bounds (lambdaReturnType red_lam) $ do
                addPrivStms slice privstms
                return red_nes

      -- One loop-carried accumulator per reduction parameter, each
      -- shaped like a tile and initialised from 'mergeinits'.
      merge <- forM (zip (lambdaParams red_lam) mergeinits) $ \(p, mergeinit) ->
        (,)
          <$> newParam
            (baseString (paramName p) ++ "_merge")
            (paramType p `arrayOfShape` tile_shape `toDecl` Unique)
          <*> pure (Var mergeinit)

      tile_id <- newVName "tile_id"
      let loopform = ForLoop tile_id Int64 num_whole_tiles []
      loopbody <- renameBody <=< runBodyBinder $
        inScopeOf loopform $
          localScope (scopeOfFParams $ map fst merge) $ do
            -- Collectively read a tile.  This loop only visits whole
            -- tiles (it runs to 'num_whole_tiles'), so we ask for a
            -- full tile; a partial tile would wrap every element read
            -- in a needless bounds check.  The residual tile is
            -- handled separately below.
            tile <- tilingReadTile tiling TileFull privstms (Var tile_id) inputs

            -- Now each thread performs a traversal of the tile and
            -- updates its accumulator.
            let accs = map (paramName . fst) merge
                tile_args =
                  ProcessTileArgs privstms red_comm red_lam map_lam tile accs (Var tile_id)
            resultBody . map Var <$> tilingProcessTile tiling tile_args

      accs <- letTupExp "accs" $ DoLoop [] merge loopform loopbody

      -- We possibly have to traverse a residual tile.  The lambdas
      -- are renamed because they were already used in the loop body.
      red_lam' <- renameLambda red_lam
      map_lam' <- renameLambda map_lam
      let residual_args =
            ResidualTileArgs privstms red_comm red_lam' map_lam' inputs accs w num_whole_tiles
      accs' <- tilingProcessResidualTile tiling residual_args

      -- Create a SegMap that takes care of the postlude for every thread.
      postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts

-- | How a tile is to be read from the input arrays.
data TileKind
  = -- | The tile may extend past the end of the input, so element
    -- reads are guarded by a bounds check (see 'readTile1D').
    TilePartial
  | -- | The tile is known to lie entirely within the input
    -- (presumably allowing unguarded reads -- confirm against the
    -- tile readers).
    TileFull

-- | Build a 'ReadPrelude' that, for a given slice, rebinds each name
-- in the second list to the result of indexing the corresponding
-- array in the first list with that slice (padded to a full slice of
-- the array's type).  The two lists are paired positionally with
-- 'zip', so surplus elements of the longer list are ignored.
mkReadPreludeValues :: [VName] -> [VName] -> ReadPrelude
mkReadPreludeValues prestms_live_arrs prestms_live slice =
  -- Each binding yields unit, so there is nothing to collect.
  forM_ (zip prestms_live_arrs prestms_live) $ \(arr, v) -> do
    arr_t <- lookupType arr
    letBindNames [v] $ BasicOp $ Index arr $ fullSlice arr_t slice

-- | Construct a 'TileReturns' kernel result for the array @arr@.
--
-- For every entry of @dims_on_top@ a unit dimension is prepended to
-- the array (via a 'Reshape'), and a matching @(size, 1)@ pair is
-- prepended to @dims@.  When @dims_on_top@ is empty the array is
-- returned as-is and no reshape is emitted.
tileReturns :: [(VName, SubExp)] -> [(SubExp, SubExp)] -> VName -> Binder Kernels KernelResult
tileReturns dims_on_top dims arr = do
  let unit_dims = replicate (length dims_on_top) (intConst Int64 1)
  arr' <-
    if null dims_on_top
      then return arr
      else do
        arr_t <- lookupType arr
        let new_shape = unit_dims ++ arrayDims arr_t
        letExp (baseString arr) $ BasicOp $ Reshape (map DimNew new_shape) arr
  let tile_dims = zip (map snd dims_on_top) unit_dims ++ dims
  return $ TileReturns tile_dims arr'

-- | Classify an input array for one-dimensional tiling.
--
-- The array is tiled along its outermost dimension (@'InputTile' [0]@)
-- when @gtid@ is not among the names recorded for it in the variance
-- table; arrays missing from the table are treated as having no
-- recorded names.  Otherwise it is marked 'InputDontTile'.
is1DTileable :: VName -> M.Map VName Names -> VName -> InputArray
is1DTileable gtid variance arr
  | not $ nameIn gtid $ M.findWithDefault mempty arr variance =
    InputTile [0] arr
  | otherwise =
    InputDontTile arr

-- | Emit a one-dimensional 'SegMap' at the given level, bound under
-- the description @desc@.
--
-- The segment space has a single fresh index ranging over the group
-- size of @lvl@.  The callback receives that index and produces the
-- per-thread results; the statements it emits become the kernel body
-- (after renaming), and every result is returned with the given
-- 'ResultManifest'.  The names bound to the 'SegMap' results are
-- returned.
segMap1D ::
  String ->
  SegLevel ->
  ResultManifest ->
  (VName -> Binder Kernels [SubExp]) ->
  Binder Kernels [VName]
segMap1D desc lvl manifest f = do
  ltid <- newVName "ltid"
  ltid_flat <- newVName "ltid_flat"
  let space = SegSpace ltid_flat [(ltid, unCount $ segGroupSize lvl)]

  -- Run the callback in a nested binder so that its statements end up
  -- inside the kernel body rather than in the enclosing scope.
  ((ts, res), stms) <- runBinder $ do
    res <- f ltid
    ts <- mapM subExpType res
    return (ts, res)
  Body _ stms' res' <- renameBody $ mkBody stms res

  letTupExp desc $
    Op $
      SegOp $
        SegMap lvl space ts $ KernelBody () stms' $ map (Returns manifest) res'

-- | Bind the name @gtid@ to @gid * group_size + ltid@, i.e. recompute
-- the global thread index from the group index and the local thread
-- index.
reconstructGtids1D ::
  Count GroupSize SubExp ->
  VName ->
  VName ->
  VName ->
  Binder Kernels ()
reconstructGtids1D group_size gtid gid ltid =
  letBindNames [gtid]
    =<< toExp (le64 gid * pe64 (unCount group_size) + le64 ltid)

readTile1D ::
  SubExp ->
  VName ->
  VName ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  TileKind ->
  PrivStms ->
  SubExp ->
  [InputArray] ->
  Binder Kernels [InputTile]
readTile1D :: SubExp
-> VName
-> VName
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> TileKind
-> PrivStms
-> SubExp
-> [InputArray]
-> Binder Kernels [InputTile]
readTile1D SubExp
tile_size VName
gid VName
gtid Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size TileKind
kind PrivStms
privstms SubExp
tile_id [InputArray]
inputs =
  ([VName] -> [InputTile])
-> BinderT Kernels (State VNameSource) [VName]
-> Binder Kernels [InputTile]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ([InputArray] -> [VName] -> [InputTile]
inputsToTiles [InputArray]
inputs)
    (BinderT Kernels (State VNameSource) [VName]
 -> Binder Kernels [InputTile])
-> ((VName -> BinderT Kernels (State VNameSource) Result)
    -> BinderT Kernels (State VNameSource) [VName])
-> (VName -> BinderT Kernels (State VNameSource) Result)
-> Binder Kernels [InputTile]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String
-> SegLevel
-> ResultManifest
-> (VName -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
segMap1D String
"full_tile" SegLevel
lvl ResultManifest
ResultNoSimplify
    ((VName -> BinderT Kernels (State VNameSource) Result)
 -> Binder Kernels [InputTile])
-> (VName -> BinderT Kernels (State VNameSource) Result)
-> Binder Kernels [InputTile]
forall a b. (a -> b) -> a -> b
$ \VName
ltid -> do
      SubExp
j <-
        String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> Binder Kernels SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"j"
          (ExpT Kernels -> Binder Kernels SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> Binder Kernels SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_id TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
tile_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
ltid)

      Count GroupSize SubExp
-> VName
-> VName
-> VName
-> BinderT Kernels (State VNameSource) ()
reconstructGtids1D Count GroupSize SubExp
group_size VName
gtid VName
gid VName
ltid
      Slice SubExp -> PrivStms -> BinderT Kernels (State VNameSource) ()
addPrivStms [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid] PrivStms
privstms

      let arrs :: [VName]
arrs = ((VName, [Int]) -> VName) -> [(VName, [Int])] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, [Int]) -> VName
forall a b. (a, b) -> a
fst ([(VName, [Int])] -> [VName]) -> [(VName, [Int])] -> [VName]
forall a b. (a -> b) -> a -> b
$ [InputArray] -> [(VName, [Int])]
tiledInputs [InputArray]
inputs
      [Type]
arr_ts <- (VName -> BinderT Kernels (State VNameSource) Type)
-> [VName] -> BinderT Kernels (State VNameSource) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> BinderT Kernels (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType [VName]
arrs
      let tile_ts :: [Type]
tile_ts = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType [Type]
arr_ts
          w :: SubExp
w = Int -> [Type] -> SubExp
forall u. Int -> [TypeBase Shape u] -> SubExp
arraysSize Int
0 [Type]
arr_ts

      let readTileElem :: VName -> BinderT Kernels (State VNameSource) VName
readTileElem VName
arr =
            -- No need for fullSlice because we are tiling only prims.
            String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"tile_elem" (BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
j])
      ([VName] -> Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var) (BinderT Kernels (State VNameSource) [VName]
 -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
-> BinderT Kernels (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
        case TileKind
kind of
          TileKind
TilePartial ->
            String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) [VName]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [VName]
letTupExp String
"pre"
              (ExpT Kernels -> BinderT Kernels (State VNameSource) [VName])
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT
  Kernels
  (State VNameSource)
  (Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf
                (TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TPrimExp Bool VName
 -> BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource)))))
-> TPrimExp Bool VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 SubExp
j TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
w)
                (Result -> BodyT Kernels
forall lore. Bindable lore => Result -> Body lore
resultBody (Result -> BodyT Kernels)
-> BinderT Kernels (State VNameSource) Result
-> Binder Kernels (BodyT Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (VName -> Binder Kernels SubExp)
-> [VName] -> BinderT Kernels (State VNameSource) Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((VName -> SubExp)
-> BinderT Kernels (State VNameSource) VName
-> Binder Kernels SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> SubExp
Var (BinderT Kernels (State VNameSource) VName
 -> Binder Kernels SubExp)
-> (VName -> BinderT Kernels (State VNameSource) VName)
-> VName
-> Binder Kernels SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> BinderT Kernels (State VNameSource) VName
readTileElem) [VName]
arrs)
                ([BinderT
   Kernels
   (State VNameSource)
   (Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody ([BinderT
    Kernels
    (State VNameSource)
    (Exp (Lore (BinderT Kernels (State VNameSource))))]
 -> BinderT
      Kernels
      (State VNameSource)
      (Body (Lore (BinderT Kernels (State VNameSource)))))
-> [BinderT
      Kernels
      (State VNameSource)
      (Exp (Lore (BinderT Kernels (State VNameSource))))]
-> BinderT
     Kernels
     (State VNameSource)
     (Body (Lore (BinderT Kernels (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (Type -> BinderT Kernels (State VNameSource) (ExpT Kernels))
-> [Type] -> [BinderT Kernels (State VNameSource) (ExpT Kernels)]
forall a b. (a -> b) -> [a] -> [b]
map Type -> BinderT Kernels (State VNameSource) (ExpT Kernels)
forall (m :: * -> *). MonadBinder m => Type -> m (Exp (Lore m))
eBlank [Type]
tile_ts)
          TileKind
TileFull ->
            (VName -> BinderT Kernels (State VNameSource) VName)
-> [VName] -> BinderT Kernels (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> BinderT Kernels (State VNameSource) VName
readTileElem [VName]
arrs
  where
    lvl :: SegLevel
lvl = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegVirt
SegNoVirt

-- | Generate the code with which every thread folds one tile into its
-- accumulators.
--
-- A thread-level 'segMap1D' (manifest 'ResultPrivate') is emitted.
-- Inside it, for the thread index @ltid@ supplied by 'segMap1D':
--
--   1. 'reconstructGtids1D' is invoked with @group_size@, @gtid@,
--      @gid@ and @ltid@ (presumably to make @gtid@ available inside
--      the new segmap -- confirm against its definition), and the
--      private statements from @tile_args@ are re-added for index
--      @ltid@.
--
--   2. Element @ltid@ of every accumulator array in 'processAcc' is
--      read.
--
--   3. If @gtid < kdim@, a redomap 'Screma' of width @tile_size@ is
--      run over the tiles, with the values read in step 2 used in
--      place of the reduction's neutral elements.  Otherwise those
--      values are returned unchanged.
--
-- The result is the list of names bound by 'segMap1D'.
processTile1D ::
  VName ->
  VName ->
  SubExp ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  ProcessTileArgs ->
  Binder Kernels [VName]
processTile1D gid gtid kdim tile_size num_groups group_size tile_args = do
  let red_comm = processComm tile_args
      privstms = processPrivStms tile_args
      map_lam = processMapLam tile_args
      red_lam = processRedLam tile_args
      tiles = processTiles tile_args
      tile_id = processTileId tile_args
      accs = processAcc tile_args

  segMap1D "acc" lvl ResultPrivate $ \ltid -> do
    reconstructGtids1D group_size gtid gid ltid
    addPrivStms [DimFix $ Var ltid] privstms

    -- We replace the neutral elements with the accumulators (this is
    -- OK because the parallel semantics are not used after this
    -- point).
    thread_accs <- forM accs $ \acc ->
      letSubExp "acc" $ BasicOp $ Index acc [DimFix $ Var ltid]

    -- Tiled inputs are used as they are; untiled inputs are sliced
    -- with 'sliceUntiled'.
    --
    -- NOTE(review): the same @tile_size@ is passed for both size
    -- arguments of 'sliceUntiled'.  'processResidualTile1D' calls
    -- this function with the residual element count as @tile_size@;
    -- confirm that 'sliceUntiled' still addresses the intended
    -- elements in that case.
    let sliceTile (InputTiled _ arr) =
          pure arr
        sliceTile (InputUntiled arr) =
          sliceUntiled arr tile_id tile_size tile_size

    tiles' <- mapM sliceTile tiles

    let form' = redomapSOAC [Reduce red_comm red_lam thread_accs] map_lam
    fmap (map Var) $
      letTupExp "acc"
        =<< eIf
          -- Only threads whose gtid is below kdim do any work.
          (toExp $ le64 gtid .<. pe64 kdim)
          (eBody [pure $ Op $ OtherOp $ Screma tile_size tiles' form'])
          (resultBodyM thread_accs)
  where
    lvl = SegThread num_groups group_size SegNoVirt

-- | Generate the code that handles the input elements left over once
-- all whole tiles have been processed.
--
-- The number of leftover elements is computed as the signed remainder
-- of 'residualInputSize' by @tile_size@.  If it is zero, the
-- accumulators in 'residualAcc' are returned unchanged.  Otherwise a
-- single tile is read with 'readTile1D' in 'TilePartial' mode, every
-- tiled input is cut down to its first @residual_input@ elements, and
-- 'processTile1D' is run with @residual_input@ in place of the tile
-- size.
--
-- The result is the list of names bound for the @if@ expression
-- ("acc_after_residual").
processResidualTile1D ::
  VName ->
  VName ->
  SubExp ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  ResidualTileArgs ->
  Binder Kernels [VName]
processResidualTile1D gid gtid kdim tile_size num_groups group_size args = do
  -- The number of residual elements that are not covered by
  -- the whole tiles.
  residual_input <-
    letSubExp "residual_input" $
      BasicOp $ BinOp (SRem Int64 Unsafe) w tile_size

  letTupExp "acc_after_residual"
    =<< eIf
      (toExp $ pe64 residual_input .==. 0)
      (resultBodyM $ map Var accs)
      (nonemptyTile residual_input)
  where
    red_comm = residualComm args
    map_lam = residualMapLam args
    red_lam = residualRedLam args
    privstms = residualPrivStms args
    inputs = residualInput args
    accs = residualAcc args
    num_whole_tiles = residualNumWholeTiles args
    w = residualInputSize args

    -- Body of the branch taken when at least one residual element
    -- exists.
    nonemptyTile residual_input = runBodyBinder $ do
      -- Collectively construct a tile.  Threads that are out-of-bounds
      -- provide a blank dummy value.
      full_tiles <-
        readTile1D
          tile_size
          gid
          gtid
          num_groups
          group_size
          TilePartial
          privstms
          num_whole_tiles
          inputs

      -- Keep only the first residual_input elements of each tiled
      -- input; untiled inputs are passed on untouched.
      let sliceTile (InputUntiled arr) =
            pure $ InputUntiled arr
          sliceTile (InputTiled perm tile) = do
            let slice =
                  DimSlice (intConst Int64 0) residual_input (intConst Int64 1)
            InputTiled perm
              <$> letExp "partial_tile" (BasicOp $ Index tile [slice])

      tiles <- mapM sliceTile full_tiles

      -- Now each thread performs a traversal of the tile and
      -- updates its accumulator.
      let tile_args =
            ProcessTileArgs privstms red_comm red_lam map_lam tiles accs num_whole_tiles
      resultBody . map Var
        <$> processTile1D gid gtid kdim residual_input num_groups group_size tile_args

-- | Construct the 'Tiling' record used for one-dimensional tiling.
--
-- Two fresh names, @gid@ and @gid_flat@, are created for the
-- group-level segment space.  The level and space then depend on
-- @dims_on_top@:
--
--   * If it is empty, the number of groups, group size and
--     virtualisation of @initial_lvl@ are reused, and the space has
--     the single dimension @gid@ ranging over the number of groups.
--
--   * Otherwise the group size is the minimum of the initial group
--     size and @kdim@, @ldim@ is @kdim@ divided (rounding up) by that
--     group size, and the number of groups is @ldim@ multiplied by all
--     the sizes in @dims_on_top@.  The space is @dims_on_top@ followed
--     by @(gid, ldim)@, and virtualisation is 'SegNoVirt'.
--
-- The tile size is the group size of the resulting level.
tiling1d :: [(VName, SubExp)] -> DoTiling VName SubExp
tiling1d dims_on_top initial_lvl gtid kdim w = do
  gid <- newVName "gid"
  gid_flat <- newVName "gid_flat"

  (lvl, space) <-
    if null dims_on_top
      then
        return
          ( SegGroup (segNumGroups initial_lvl) (segGroupSize initial_lvl) $ segVirt initial_lvl,
            SegSpace gid_flat [(gid, unCount $ segNumGroups initial_lvl)]
          )
      else do
        group_size <-
          letSubExp "computed_group_size" $
            BasicOp $ BinOp (SMin Int64) (unCount (segGroupSize initial_lvl)) kdim

        -- How many groups we need to exhaust the innermost dimension.
        --
        -- NOTE(review): group_size is min(initial group size, kdim),
        -- so it is zero when kdim is zero, and this division is
        -- marked Unsafe -- confirm that callers guarantee kdim > 0.
        ldim <-
          letSubExp "ldim" $
            BasicOp $ BinOp (SDivUp Int64 Unsafe) kdim group_size

        num_groups <-
          letSubExp "computed_num_groups"
            =<< foldBinOp (Mul Int64 OverflowUndef) ldim (map snd dims_on_top)

        return
          ( SegGroup (Count num_groups) (Count group_size) SegNoVirt,
            SegSpace gid_flat $ dims_on_top ++ [(gid, ldim)]
          )
  let tile_size = unCount $ segGroupSize lvl

  return
    Tiling
      { -- Bind gtid to gid * tile_size + ltid, then hand the caller
        -- the in-bounds predicate (gtid < kdim) and the slice [ltid].
        tilingSegMap = \desc lvl' manifest f -> segMap1D desc lvl' manifest $ \ltid -> do
          letBindNames [gtid]
            =<< toExp (le64 gid * pe64 tile_size + le64 ltid)
          f (untyped $ le64 gtid .<. pe64 kdim) [DimFix $ Var ltid],
        tilingReadTile =
          readTile1D tile_size gid gtid (segNumGroups lvl) (segGroupSize lvl),
        tilingProcessTile =
          processTile1D gid gtid kdim tile_size (segNumGroups lvl) (segGroupSize lvl),
        tilingProcessResidualTile =
          processResidualTile1D gid gtid kdim tile_size (segNumGroups lvl) (segGroupSize lvl),
        tilingTileReturns = tileReturns dims_on_top [(kdim, tile_size)],
        tilingTileShape = Shape [tile_size],
        -- Number of tiles that fit entirely within w.
        tilingNumWholeTiles =
          letSubExp "num_whole_tiles" $
            BasicOp $ BinOp (SQuot Int64 Unsafe) w tile_size,
        tilingLevel = lvl,
        tilingSpace = space
      }

-- | Classify an input array by how it varies with the two innermost
-- kernel dimensions.
--
-- The two innermost entries of @dims@ are inspected.  The array is
-- reported as 'InputTile' only when neither of those dimensions is in
-- @branch_variant@ and the array's entry in @variance@ contains exactly
-- one of them; the permutation records which one.  Every other array
-- is reported as 'InputDontTile'.
--
-- Yields 'Nothing' when @dims@ has fewer than two elements.
invariantToOneOfTwoInnerDims ::
  Names ->
  M.Map VName Names ->
  [VName] ->
  VName ->
  Maybe InputArray
invariantToOneOfTwoInnerDims branch_variant variance dims arr =
  case reverse dims of
    inner : outer : _ -> Just $ classify inner outer
    _ -> Nothing
  where
    -- Names that 'arr' is recorded as variant to; empty if unknown.
    arr_variance = M.findWithDefault mempty arr variance

    classify inner outer
      | branch_dependent = InputDontTile arr
      | on_outer && not on_inner = InputTile [0, 1] arr
      | on_inner && not on_outer = InputTile [1, 0] arr
      | otherwise = InputDontTile arr
      where
        branch_dependent =
          inner `nameIn` branch_variant || outer `nameIn` branch_variant
        on_inner = inner `nameIn` arr_variance
        on_outer = outer `nameIn` arr_variance

-- | Bind the two global thread ID names from group and local IDs.
--
-- For each of the two dimensions this binds
-- @gtid = gid * tile_size + ltid@, first for the x component and then
-- for the y component.
reconstructGtids2D ::
  SubExp ->
  (VName, VName) ->
  (VName, VName) ->
  (VName, VName) ->
  Binder Kernels ()
reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y) = do
  bindGtid gtid_x gid_x ltid_x
  bindGtid gtid_y gid_y ltid_y
  where
    -- Bind a single global ID from its group ID and local ID.
    bindGtid gtid gid ltid =
      letBindNames [gtid]
        =<< toExp (le64 gid * pe64 tile_size + le64 ltid)

-- | Read one two-dimensional tile of every tiled input array.
--
-- Emits a thread-level 'segMap2D' called @full_tile@ over a
-- @tile_size@ by @tile_size@ space.  Each thread computes the indices
-- @i = tile_id * tile_size + ltid_x@ and
-- @j = tile_id * tile_size + ltid_y@, rebinds the global thread IDs,
-- re-adds the private statements, and then reads a single element from
-- each tiled input, using the input's permutation to choose between
-- @i@ and @j@.
--
-- With 'TileFull' the element is read unconditionally.  With
-- 'TilePartial' the read is guarded: the index must be smaller than
-- the outer size of the array and the relevant global thread ID must be
-- within its kernel dimension; otherwise a blank value is produced.
readTile2D ::
  (SubExp, SubExp) ->
  (VName, VName) ->
  (VName, VName) ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  TileKind ->
  PrivStms ->
  SubExp ->
  [InputArray] ->
  Binder Kernels [InputTile]
readTile2D (kdim_x, kdim_y) (gtid_x, gtid_y) (gid_x, gid_y) tile_size num_groups group_size kind privstms tile_id inputs =
  inputsToTiles inputs
    <$> segMap2D "full_tile" lvl ResultNoSimplify (tile_size, tile_size) readElems
  where
    lvl = SegThread num_groups group_size SegNoVirtFull
    arrs_and_perms = tiledInputs inputs

    -- Offset of the first element covered by this tile.
    tile_offset = pe64 tile_id * pe64 tile_size

    readElems (ltid_x, ltid_y) = do
      i <- letSubExp "i" =<< toExp (tile_offset + le64 ltid_x)
      j <- letSubExp "j" =<< toExp (tile_offset + le64 ltid_y)

      reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y)
      addPrivStms (map (DimFix . Var) [ltid_x, ltid_y]) privstms

      let -- The index (i or j) that this input is read at.
          elemIndex perm = last $ rearrangeShape perm [i, j]

          -- No need for fullSlice because we are tiling only prims.
          readTileElem (arr, perm) =
            letExp "tile_elem" . BasicOp $
              Index arr [DimFix $ elemIndex perm]

          readTileElemIfInBounds (arr, perm) = do
            arr_t <- lookupType arr
            let idx = elemIndex perm
                idx_in_bounds = pe64 idx .<. pe64 (arraySize 0 arr_t)
                gtid_in_bounds =
                  last $
                    rearrangeShape
                      perm
                      [ le64 gtid_y .<. pe64 kdim_y,
                        le64 gtid_x .<. pe64 kdim_x
                      ]
            eIf
              (toExp $ idx_in_bounds .&&. gtid_in_bounds)
              (eBody [pure . BasicOp $ Index arr [DimFix idx]])
              (eBody [eBlank $ rowType arr_t])

      map Var <$> case kind of
        TileFull ->
          mapM readTileElem arrs_and_perms
        TilePartial ->
          mapM (letExp "pre" <=< readTileElemIfInBounds) arrs_and_perms

-- | The outer size of the first tiled input among @tiles@, looked up
-- from its type.  When no input is tiled, the constant 0 (as an
-- 'Int64') is returned instead.
findTileSize :: HasScope lore m => [InputTile] -> m SubExp
findTileSize tiles =
  case mapMaybe tiledName tiles of
    [] -> pure $ intConst Int64 0
    first_tile : _ -> arraySize 0 <$> lookupType first_tile
  where
    -- The tile array of a tiled input; untiled inputs are dropped.
    tiledName = \case
      InputTiled _ tile -> Just tile
      InputUntiled {} -> Nothing

-- | Fold one two-dimensional tile into the per-thread accumulators.
--
-- Emits a thread-level 'segMap2D' called @acc@ over a @tile_size@ by
-- @tile_size@ space with private results.  Each thread rebinds its
-- global thread IDs, re-adds the private statements, reads its current
-- accumulator values, takes its slice of every input tile, and then
-- either runs the redomap over that slice (when both global thread IDs
-- are within the kernel dimensions) or returns its accumulators
-- unchanged.
processTile2D ::
  (VName, VName) ->
  (VName, VName) ->
  (SubExp, SubExp) ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  ProcessTileArgs ->
  Binder Kernels [VName]
processTile2D (gid_x, gid_y) (gtid_x, gtid_y) (kdim_x, kdim_y) tile_size num_groups group_size tile_args = do
  -- Might be truncated in case of a partial tile.
  actual_tile_size <- findTileSize tiles
  segMap2D
    "acc"
    lvl
    ResultPrivate
    (tile_size, tile_size)
    (accumulate actual_tile_size)
  where
    privstms = processPrivStms tile_args
    red_comm = processComm tile_args
    red_lam = processRedLam tile_args
    map_lam = processMapLam tile_args
    tiles = processTiles tile_args
    accs = processAcc tile_args
    tile_id = processTileId tile_args

    lvl = SegThread num_groups group_size SegNoVirtFull

    -- Whether this thread corresponds to a point inside the kernel space.
    thread_in_bounds =
      (le64 gtid_x .<. pe64 kdim_x) .&&. (le64 gtid_y .<. pe64 kdim_y)

    accumulate actual_tile_size (ltid_x, ltid_y) = do
      let local_slice = [DimFix $ Var ltid_x, DimFix $ Var ltid_y]

      reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y)

      addPrivStms local_slice privstms

      -- We replace the neutral elements with the accumulators (this is
      -- OK because the parallel semantics are not used after this
      -- point).
      let readAcc acc = letSubExp "acc" . BasicOp $ Index acc local_slice
      thread_accs <- mapM readAcc accs

      let form' = redomapSOAC [Reduce red_comm red_lam thread_accs] map_lam

          sliceTile = \case
            InputUntiled arr ->
              sliceUntiled arr tile_id tile_size actual_tile_size
            InputTiled perm tile -> do
              tile_t <- lookupType tile
              -- The local ID selected by the head of the permutation.
              let ltid = head $ rearrangeShape perm [ltid_x, ltid_y]
              letExp "tile" . BasicOp . Index tile $
                sliceAt tile_t (head perm) [DimFix $ Var ltid]

      tiles' <- mapM sliceTile tiles

      map Var
        <$> ( letTupExp "acc"
                =<< eIf
                  (toExp thread_in_bounds)
                  (eBody [pure . Op . OtherOp $ Screma actual_tile_size tiles' form'])
                  (resultBodyM thread_accs)
            )

processResidualTile2D ::
  (VName, VName) ->
  (VName, VName) ->
  (SubExp, SubExp) ->
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  ResidualTileArgs ->
  Binder Kernels [VName]
processResidualTile2D :: (VName, VName)
-> (VName, VName)
-> (SubExp, SubExp)
-> SubExp
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> ResidualTileArgs
-> BinderT Kernels (State VNameSource) [VName]
-- Deal with whatever part of the input is left over after the whole
-- tiles: compute @w `SRem` tile_size@ (with @w@ taken from
-- 'residualInputSize'); when that is zero the accumulators in @args@
-- are returned as they are, otherwise a single partial tile is read
-- and handed to 'processTile2D'.
processResidualTile2D gids gtids kdims tile_size num_groups group_size args = do
  -- The number of residual elements that are not covered by the whole
  -- tiles.
  residual_input <-
    letSubExp
      "residual_input"
      (BasicOp (BinOp (SRem Int64 Unsafe) (residualInputSize args) tile_size))

  -- Build the branch first, then bind its results.
  branch <-
    eIf
      (toExp (pe64 residual_input .==. 0))
      (resultBodyM (map Var (residualAcc args)))
      (processLeftover residual_input)
  letTupExp "acc_after_residual" branch
  where
    -- Construct the body for the non-empty case and rename it
    -- afterwards.
    processLeftover leftover =
      runBodyBinder (leftoverBody leftover) >>= renameBody

    leftoverBody leftover = do
      -- Collectively construct a tile.  Threads that are out-of-bounds
      -- provide a blank dummy value.
      full_tile <-
        readTile2D
          kdims
          gtids
          gids
          tile_size
          num_groups
          group_size
          TilePartial
          (residualPrivStms args)
          (residualNumWholeTiles args)
          (residualInput args)

      -- Keep only the leading @leftover@ elements along both
      -- dimensions of every tiled input; untiled inputs pass through.
      let leading = DimSlice (intConst Int64 0) leftover (intConst Int64 1)
          shrink (InputTiled perm whole) =
            InputTiled perm
              <$> letExp "partial_tile" (BasicOp (Index whole [leading, leading]))
          shrink (InputUntiled arr) =
            pure (InputUntiled arr)
      tiles <- mapM shrink full_tile

      let tile_args =
            ProcessTileArgs
              (residualPrivStms args)
              (residualComm args)
              (residualRedLam args)
              (residualMapLam args)
              tiles
              (residualAcc args)
              (residualNumWholeTiles args)

      -- Now each thread performs a traversal of the tile and updates
      -- its accumulator.
      new_accs <-
        processTile2D
          gids
          gtids
          kdims
          tile_size
          num_groups
          group_size
          tile_args
      pure (resultBody (map Var new_accs))

-- | Build the 'Tiling' record for two-dimensional tiling.  A tile size
-- is obtained through a 'GetSize' of class 'SizeTile', the group size
-- is that tile size multiplied by itself, and the number of groups is
-- the product of the per-dimension 'SDivUp' counts and the sizes in
-- @dims_on_top@.
tiling2d :: [(VName, SubExp)] -> DoTiling (VName, VName) (SubExp, SubExp)
tiling2d dims_on_top _initial_lvl (gtid_x, gtid_y) (kdim_x, kdim_y) w = do
  gid_x <- newVName "gid_x"
  gid_y <- newVName "gid_y"

  -- The key of the tile size is the pretty-printed fresh name.
  tile_size_name <- newVName "tile_size"
  let tile_size_key = nameFromString (pretty tile_size_name)
  tile_size <-
    letSubExp "tile_size" (Op (SizeOp (GetSize tile_size_key SizeTile)))
  group_size <-
    letSubExp
      "group_size"
      (BasicOp (BinOp (Mul Int64 OverflowUndef) tile_size tile_size))

  -- Rounded-up division of a dimension by the tile size.
  let ceilDivByTile dim = BasicOp (BinOp (SDivUp Int64 Unsafe) dim tile_size)
  num_groups_x <- letSubExp "num_groups_x" (ceilDivByTile kdim_x)
  num_groups_y <- letSubExp "num_groups_y" (ceilDivByTile kdim_y)

  num_groups_exp <-
    foldBinOp
      (Mul Int64 OverflowUndef)
      num_groups_x
      (num_groups_y : map snd dims_on_top)
  num_groups <- letSubExp "num_groups_top" num_groups_exp

  gid_flat <- newVName "gid_flat"
  let lvl = SegGroup (Count num_groups) (Count group_size) SegNoVirtFull
      space =
        SegSpace
          gid_flat
          (dims_on_top ++ [(gid_x, num_groups_x), (gid_y, num_groups_y)])

      -- Arguments shared by several of the record fields below.
      gids = (gid_x, gid_y)
      gtids = (gtid_x, gtid_y)
      kdims = (kdim_x, kdim_y)
      grid = segNumGroups lvl
      block = segGroupSize lvl

      -- Both global thread indices must lie within the kernel
      -- dimensions.
      in_bounds =
        untyped
          ( (le64 gtid_x .<. pe64 kdim_x)
              .&&. (le64 gtid_y .<. pe64 kdim_y)
          )

      tiledSegMap desc lvl' manifest f =
        segMap2D desc lvl' manifest (tile_size, tile_size) $
          \(ltid_x, ltid_y) -> do
            reconstructGtids2D tile_size gtids gids (ltid_x, ltid_y)
            f in_bounds [DimFix (Var ltid_x), DimFix (Var ltid_y)]

      wholeTiles =
        letSubExp
          "num_whole_tiles"
          (BasicOp (BinOp (SQuot Int64 Unsafe) w tile_size))

  pure
    Tiling
      { tilingSegMap = tiledSegMap,
        tilingReadTile = readTile2D kdims gtids gids tile_size grid block,
        tilingProcessTile =
          processTile2D gids gtids kdims tile_size grid block,
        tilingProcessResidualTile =
          processResidualTile2D gids gtids kdims tile_size grid block,
        tilingTileReturns =
          tileReturns dims_on_top [(kdim_x, tile_size), (kdim_y, tile_size)],
        tilingTileShape = Shape [tile_size, tile_size],
        tilingNumWholeTiles = wholeTiles,
        tilingLevel = lvl,
        tilingSpace = space
      }