{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | Perform a restricted form of loop tiling within SegMaps.  We only
-- tile primitive types, to avoid excessive local memory use.
module Futhark.Optimise.TileLoops
       ( tileLoops )
       where

import Control.Monad.State
import Control.Monad.Reader
import qualified Data.Sequence as Seq
import qualified Data.Map.Strict as M
import Data.List (foldl')

import Prelude hiding (quot)

import Futhark.MonadFreshNames
import Futhark.Representation.Kernels
import Futhark.Transform.Rename
import Futhark.Pass
import Futhark.Tools

-- | The loop-tiling pass.  It rebuilds the program with the same
-- constants and with every function definition run through
-- 'optimiseFunDef'.
tileLoops :: Pass Kernels Kernels
tileLoops = Pass "tile loops" "Tile stream loops inside kernels" $
            \(Prog consts funs) ->
              Prog consts <$> mapM optimiseFunDef funs

-- | Optimise the body of a single function definition.  The body is
-- processed by 'optimiseBody' in a reader environment initialised
-- with the scope of the function parameters; the name source of the
-- surrounding monad is threaded through via 'modifyNameSource'.
-- Everything except 'funDefBody' is left unchanged.
optimiseFunDef :: MonadFreshNames m => FunDef Kernels -> m (FunDef Kernels)
optimiseFunDef fundec = do
  body' <- modifyNameSource $ runState $
           runReaderT m (scopeOfFParams (funDefParams fundec))
  return fundec { funDefBody = body' }
  where m = optimiseBody $ funDefBody fundec

-- | The monad in which the functions of this pass run: a reader over
-- the current 'Scope', layered on a state monad holding the
-- 'VNameSource'.
type TileM = ReaderT (Scope Kernels) (State VNameSource)

-- | Optimise every statement of a body with 'optimiseStm' and
-- concatenate the resulting statement sequences.  The body result is
-- kept as is.  The statements of the body are brought into scope
-- while they are being processed.
optimiseBody :: Body Kernels -> TileM (Body Kernels)
optimiseBody (Body () bnds res) = localScope (scopeOf bnds) $
  Body () <$> (mconcat <$> mapM optimiseStm (stmsToList bnds)) <*> pure res

-- | Optimise a single statement, producing the statements that
-- replace it.
--
-- A thread-level 'SegMap' is handed to 'tileInKernelBody'; whatever
-- statements that returns are placed in front of the rebuilt
-- 'SegMap', which keeps the original pattern, auxiliary information
-- and result types.  For any other statement we only recurse into
-- the bodies nested in its expression.
optimiseStm :: Stm Kernels -> TileM (Stms Kernels)
optimiseStm (Let pat aux (Op (SegOp (SegMap lvl@SegThread{} space ts kbody)))) = do
  (host_stms, (lvl', space', kbody')) <-
    tileInKernelBody mempty initial_variance lvl space ts kbody
  return $ host_stms <>
    oneStm (Let pat aux $ Op $ SegOp $ SegMap lvl' space' ts kbody')
  where -- Every name bound by the segment space starts out mapped to
        -- the empty set of names.
        initial_variance = M.map mempty $ scopeOfSegSpace space

optimiseStm (Let pat aux e) =
  pure <$> (Let pat aux <$> mapExpM optimise e)
  where optimise = identityMapper { mapOnBody = \scope -> localScope scope . optimiseBody }

-- | Try to tile the body of a 'SegMap'.
--
-- Tiling is only attempted when every kernel result is a plain
-- 'Returns'.  If that is not the case, or if 'tileInBody' finds
-- nothing to tile, the given level, space and kernel body are
-- returned unchanged together with an empty statement sequence.
--
-- On success the result holds the statements produced by
-- 'tileInBody' (the caller in this module places them before the
-- rewritten 'SegMap'), the level and space of the chosen tiling, and
-- a new kernel body whose results are built with 'tilingTileReturns'.
tileInKernelBody :: Names -> VarianceTable
                 -> SegLevel -> SegSpace -> [Type] -> KernelBody Kernels
                 -> TileM (Stms Kernels, (SegLevel, SegSpace, KernelBody Kernels))
tileInKernelBody branch_variant initial_variance lvl initial_kspace ts kbody
  | Just kbody_res <- mapM isSimpleResult $ kernelBodyResult kbody = do
      maybe_tiled <-
        tileInBody branch_variant initial_variance lvl initial_kspace ts $
        Body () (kernelBodyStms kbody) kbody_res
      case maybe_tiled of
        Just (host_stms, tiling, tiledBody) -> do
          -- The tiled body is run with no private statements.
          (res', stms') <-
            runBinder $ mapM (tilingTileReturns tiling) =<< tiledBody mempty
          return (host_stms, (tilingLevel tiling,
                              tilingSpace tiling,
                              KernelBody () stms' res'))
        Nothing ->
          return (mempty, (lvl, initial_kspace, kbody))
  | otherwise =
      return (mempty, (lvl, initial_kspace, kbody))
  where -- Only results of the form @Returns _ se@ are accepted.
        isSimpleResult (Returns _ se) = Just se
        isSimpleResult _ = Nothing

-- | Look for something to tile among the statements of a body.
--
-- The statements are visited in order.  At each statement the
-- following are tried, first match wins:
--
--   1. 1D tiling, if the statement is 'tileable';
--
--   2. 2D tiling, if the statement is 'tileable';
--
--   3. if the statement is a for-loop with no context merge
--      parameters and no loop-bound array variables, tiling inside
--      the loop body (by a recursive call).
--
-- If none applies the statement is added to the prelude and the
-- search continues with the next statement.  'Nothing' is returned
-- when the statements run out.
tileInBody :: Names -> VarianceTable
           -> SegLevel -> SegSpace -> [Type] -> Body Kernels
           -> TileM (Maybe (Stms Kernels, Tiling, TiledBody))
tileInBody branch_variant initial_variance initial_lvl initial_space res_ts (Body () initial_kstms stms_res) =
  descend mempty $ stmsToList initial_kstms
  where
    variance = varianceInStms initial_variance initial_kstms

    descend _ [] =
      return Nothing

    descend prestms (stm_to_tile:poststms)

      -- 1D tiling of redomap.  Requires that no input array has the
      -- last thread index of the space in its variance entry, and
      -- that this index is not in 'branch_variant'.
      | (gtid, kdim) : top_space_rev <- reverse $ unSegSpace initial_space,
        Just (w, arrs, form) <- tileable stm_to_tile,
        not $ any (nameIn gtid .
                   flip (M.findWithDefault mempty) variance) arrs,
        not $ gtid `nameIn` branch_variant,
        (prestms', poststms') <-
          preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms),
        used <- freeIn stm_to_tile <> freeIn stms_res =

          Just . injectPrelude initial_space variance prestms' used <$>
          tileGeneric (tiling1d $ reverse top_space_rev)
          initial_lvl res_ts (stmPattern stm_to_tile)
          gtid kdim
          w form (zip arrs $ repeat [0]) poststms' stms_res

      -- 2D tiling of redomap.  Requires at least two dimensions in
      -- the space and that 'invariantToOneOfTwoInnerDims' succeeds
      -- for every input array.
      | (gtids, kdims) <- unzip $ unSegSpace initial_space,
        Just (w, arrs, form) <- tileable stm_to_tile,
        Just inner_perm <- mapM (invariantToOneOfTwoInnerDims branch_variant variance gtids) arrs,
        gtid_y : gtid_x : top_gtids_rev <- reverse gtids,
        kdim_y : kdim_x : top_kdims_rev <- reverse kdims,
        (prestms', poststms') <-
          preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms),
        used <- freeIn stm_to_tile <> freeIn stms_res =

          Just . injectPrelude initial_space variance prestms' used <$>
          tileGeneric (tiling2d $ reverse $ zip top_gtids_rev top_kdims_rev)
          initial_lvl res_ts (stmPattern stm_to_tile)
          (gtid_x, gtid_y) (kdim_x, kdim_y)
          w form (zip arrs inner_perm) poststms' stms_res

      -- Tiling inside for-loop.
      | DoLoop [] merge (ForLoop i it bound []) loopbody <- stmExp stm_to_tile,
        (prestms', poststms') <-
          preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms) = do

          -- Inside the loop, the variance entries of the names free
          -- in the loop bound are added to the branch-variant set.
          let branch_variant' =
                branch_variant <>
                mconcat (map (flip (M.findWithDefault mempty) variance)
                         (namesToList (freeIn bound)))
              merge_params = map fst merge

          maybe_tiled <-
            localScope (M.insert i (IndexInfo it) $ scopeOfFParams merge_params) $
            tileInBody branch_variant' variance initial_lvl initial_space
            (map paramType merge_params) $ mkBody (bodyStms loopbody) (bodyResult loopbody)

          case maybe_tiled of
            Nothing -> next
            Just tiled ->
              Just <$> tileDoLoop initial_space variance prestms'
              (freeIn loopbody <> freeIn merge) tiled
              res_ts (stmPattern stm_to_tile) (stmAux stm_to_tile)
              merge i it bound poststms' stms_res

      | otherwise = next

      where -- Give up on this statement: bring it into scope, append
            -- it to the prelude, and carry on with the remainder.
            next = localScope (scopeOf stm_to_tile) $
                   descend (prestms <> oneStm stm_to_tile) poststms

-- | Split the prelude around the statement that is about to be tiled.
-- Prelude statements binding a name that the tiled statement uses
-- (directly, or through the variance table) stay in the prelude; all
-- remaining prelude statements are moved to the front of the postlude.
preludeToPostlude :: VarianceTable
                  -> Stms Kernels -> Stm Kernels -> Stms Kernels
                  -> (Stms Kernels, Stms Kernels)
preludeToPostlude variance prelude stm_to_tile postlude =
  (kept, moved <> postlude)
  where
    -- Names occurring free in the statement we are going to tile.
    direct_deps :: Names
    direct_deps = freeIn stm_to_tile

    -- Variance-table entry for a name; names without an entry
    -- contribute nothing.
    varianceOf :: VName -> Names
    varianceOf v = M.findWithDefault mempty v variance

    -- The free names together with everything the variance table
    -- records for each of them.
    needed :: Names
    needed = direct_deps <> foldMap varianceOf (namesToList direct_deps)

    -- A prelude statement is kept if any name it binds is needed.
    bindsNeeded :: Stm Kernels -> Bool
    bindsNeeded = any (`nameIn` needed) . patternNames . stmPattern

    (kept, moved) = Seq.partition bindsNeeded prelude

-- | Partition prelude statements preceding a tiled loop (or something
-- containing a tiled loop) into three categories, returned in this
-- order:
--
-- 1) Group-level statements that are invariant to the threads in the group.
--
-- 2) Thread-variant statements that should be computed once with a segmap_thread_scalar.
--
-- 3) Thread-variant statements that should be recomputed whenever
-- they are needed.
--
-- The third category duplicates computation, so we only want to do it
-- when absolutely necessary.  Currently, this is necessary for
-- results that are views of an array (slicing, rotate, etc), because
-- these cannot be efficiently represented by a scalar segmap (they'll
-- be manifested in memory).
partitionPrelude :: VarianceTable -> Stms Kernels -> Names
                 -> (Stms Kernels, Stms Kernels, Stms Kernels)
partitionPrelude variance prestms tiled_kdims =
  (group_invariant, precomputed, recomputed)
  where
    boundBy :: Stm Kernels -> [VName]
    boundBy = patternNames . stmPattern

    -- Does the statement vary with any of the given names?  Only the
    -- variance of the first bound name is consulted; a statement that
    -- binds nothing is treated as not depending on anything.
    dependsOn :: Names -> Stm Kernels -> Bool
    dependsOn names stm =
      case boundBy stm of
        [] -> False
        v : _ ->
          any (`nameIn` names) (namesToList (M.findWithDefault mempty v variance))

    (variant, group_invariant) =
      Seq.partition (dependsOn tiled_kdims) prestms

    -- Array views: an index that leaves at least one dimension, or a
    -- rotate/rearrange/reshape.
    isView :: Stm Kernels -> Bool
    isView stm =
      case stmExp stm of
        BasicOp (Index _ slice) -> not (null (sliceDims slice))
        BasicOp Rotate{} -> True
        BasicOp Rearrange{} -> True
        BasicOp Reshape{} -> True
        _ -> False

    -- Names bound by the thread-variant view statements.
    view_names :: Names
    view_names =
      namesFromList (concatMap boundBy (filter isView (stmsToList variant)))

    -- Recompute the views themselves, and anything that varies with them.
    needsRecompute :: Stm Kernels -> Bool
    needsRecompute stm =
      any (`nameIn` view_names) (boundBy stm) || dependsOn view_names stm

    (recomputed, precomputed) = Seq.partition needsRecompute variant

-- | Wrap an already-tiled body so that the given prelude statements
-- are made available to it.  The host statements and the tiling are
-- passed through untouched; only the 'TiledBody' is replaced.
--
-- The prelude is split with 'partitionPrelude': the group-invariant
-- part is emitted directly, the precomputed part is evaluated once via
-- 'doPrelude', and the recomputed part is handed to the inner body as
-- 'PrivStms' (together with the action built by 'mkReadPreludeValues'),
-- placed in front of whatever private statements the caller supplies.
injectPrelude :: SegSpace -> VarianceTable
              -> Stms Kernels -> Names
              -> (Stms Kernels, Tiling, TiledBody)
              -> (Stms Kernels, Tiling, TiledBody)
injectPrelude initial_space variance prestms used (host_stms, tiling, tiledBody) =
  (host_stms, tiling, withPrelude)
  where
    tiling_dims = unSegSpace (tilingSpace tiling)

    -- Dimensions of the initial space that do not occur unchanged in
    -- the space of the tiling.
    tiled_kdims =
      namesFromList
      [ gtid
      | dim@(gtid, _) <- unSegSpace initial_space
      , dim `notElem` tiling_dims
      ]

    (group_stms, precomputed, recomputed) =
      partitionPrelude variance prestms tiled_kdims

    -- Precomputed names needed afterwards: those the caller reports as
    -- used, plus those consumed by the recomputed statements.
    live_set =
      namesToList (liveSet precomputed (used <> freeIn recomputed))

    withPrelude :: TiledBody
    withPrelude privstms = do
      addStms group_stms
      prelude_arrs <-
        inScopeOf precomputed (doPrelude tiling precomputed live_set)
      let readBack = mkReadPreludeValues prelude_arrs live_set
      tiledBody (PrivStms recomputed readBack <> privstms)

-- | Wrap an already-tiled body in a sequential @for@ loop.
--
-- The arguments are, in order: the initial kernel space, the variance
-- table, the prelude statements preceding the loop, the names used by
-- the loop (body and merge parameters), the result of tiling the loop
-- body, the result types, the pattern and auxiliary information of the
-- original loop statement, the loop merge parameters with their
-- initial values, the loop index, its integer type, the loop bound,
-- and finally the postlude statements and their result.
--
-- The host statements and the tiling are returned unchanged.  The new
-- 'TiledBody':
--
-- 1. emits the group-invariant prelude and runs the precomputed
--    prelude through 'doPrelude';
--
-- 2. creates fresh merge parameters (suffixed @_group@) whose types are
--    the original ones expanded by 'tilingTileShape', initialised by a
--    @tiled_loopinit@ 'tilingSegMap';
--
-- 3. builds a 'DoLoop' around the inner tiled body, passing it private
--    statements that index the expanded merge parameters;
--
-- 4. finishes with 'postludeGeneric'.
tileDoLoop :: SegSpace -> VarianceTable
           -> Stms Kernels -> Names
           -> (Stms Kernels, Tiling, TiledBody)
           -> [Type] -> Pattern Kernels -> StmAux (ExpAttr Kernels)
           -> [(FParam Kernels, SubExp)] -> VName -> IntType -> SubExp
           -> Stms Kernels -> Result
           -> TileM (Stms Kernels, Tiling, TiledBody)
tileDoLoop initial_space variance prestms used_in_body (host_stms, tiling, tiledBody) res_ts pat aux merge i it bound poststms poststms_res = do

  let (invariant_prestms,
       precomputed_variant_prestms,
       recomputed_variant_prestms) =
        partitionPrelude variance prestms tiled_kdims

  let (mergeparams, mergeinits) = unzip merge

      -- Expand the loop merge parameters to be arrays.
      tileDim t = arrayOf t (tilingTileShape tiling) $ uniqueness t

      tiledBody' privstms = inScopeOf host_stms $ do
        addStms invariant_prestms

        -- The recomputed statements are re-emitted per thread as part
        -- of 'inloop_privstms', so any precomputed value they consume
        -- must be live as well, not only the names used by the loop
        -- itself.  This matches how 'injectPrelude' computes its live
        -- set.
        let live_set = namesToList $ liveSet precomputed_variant_prestms $
                       used_in_body <> freeIn recomputed_variant_prestms
        prelude_arrs <- inScopeOf precomputed_variant_prestms $
                        doPrelude tiling precomputed_variant_prestms live_set

        -- One fresh, tile-shaped parameter per original merge parameter.
        mergeparams' <- forM mergeparams $ \(Param pname pt) ->
          Param <$> newVName (baseString pname ++ "_group") <*> pure (tileDim pt)

        let merge_ts = map paramType mergeparams

        -- Private statements giving access to the prelude: the
        -- recomputed statements plus the read-back action for the live
        -- precomputed values.
        let inloop_privstms =
              PrivStms recomputed_variant_prestms $
              mkReadPreludeValues prelude_arrs live_set

        -- Compute the initial merge values, guarded by the in-bounds
        -- predicate supplied by 'tilingSegMap'.
        mergeinit' <-
          fmap (map Var) $ certifying (stmAuxCerts aux) $
          tilingSegMap tiling "tiled_loopinit" (scalarLevel tiling) ResultPrivate $
          \in_bounds slice ->
            fmap (map Var) $ protectOutOfBounds "loopinit" in_bounds merge_ts $ do
              addPrivStms slice inloop_privstms
              addPrivStms slice privstms
              return mergeinits

        let merge' = zip mergeparams' mergeinit'

        -- Rebind every original merge parameter name by indexing the
        -- corresponding expanded parameter with the given slice.
        let indexMergeParams slice =
              localScope (scopeOfFParams mergeparams') $
              forM_ (zip mergeparams mergeparams') $ \(to, from) ->
                letBindNames_ [paramName to] $ BasicOp $ Index (paramName from) $
                fullSlice (paramType from) slice

        loopbody' <- runBodyBinder $ resultBody . map Var <$>
                     tiledBody (privstms <> inloop_privstms <> PrivStms mempty indexMergeParams)
        accs' <- letTupExp "tiled_inside_loop" $
                 DoLoop [] merge' (ForLoop i it bound []) loopbody'

        postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts

  return (host_stms, tiling, tiledBody')

  where -- Dimensions of the initial space that do not occur unchanged
        -- in the space of the tiling.
        tiled_kdims = namesFromList $ map fst $
                      filter (`notElem` unSegSpace (tilingSpace tiling)) $
                      unSegSpace initial_space

-- | Create a SegMap (named @prelude@) that takes care of the prelude
-- for every thread.  The given statements are executed only where the
-- in-bounds predicate supplied by 'tilingSegMap' holds; elsewhere
-- blank values of the same types are produced.  The per-thread results
-- are the values of @prestms_live@, in that order.
doPrelude :: Tiling -> Stms Kernels -> [VName] -> Binder Kernels [VName]
doPrelude tiling prestms prestms_live =
  tilingSegMap tiling "prelude" (scalarLevel tiling) ResultPrivate $
  \in_bounds _slice -> do
    live_ts <- mapM lookupType prestms_live
    let whenInBounds = addStms prestms >> resultBodyM (map Var prestms_live)
        whenOutOfBounds = eBody (map eBlank live_ts)
    guarded <- eIf (toExp in_bounds) whenInBounds whenOutOfBounds
    map Var <$> letTupExp "pre" guarded

-- | The names bound by the given statements that also occur free in
-- @after@.
liveSet :: FreeIn a => Stms Kernels -> a -> Names
liveSet stms after = namesIntersection bound_here (freeIn after)
  where
    bound_here :: Names
    bound_here = namesFromList (foldMap (patternNames . stmPattern) stms)

-- | Decide whether a statement can be tiled.  On success, return the
-- width, the input arrays, and the reduction's commutativity, operator,
-- neutral elements and map function.  The statement must be a 'Screma'
-- recognised by 'isRedomapSOAC', with at least one input array, no
-- mapout arrays, and only primitive map parameters and results.
tileable :: Stm Kernels
         -> Maybe (SubExp, [VName],
                   (Commutativity, Lambda Kernels, [SubExp], Lambda Kernels))
tileable stm = do
  -- A failed pattern match here yields 'Nothing'.
  Op (OtherOp (Screma w form arrs)) <- Just (stmExp stm)
  (reds, map_lam) <- isRedomapSOAC form
  let Reduce red_comm red_lam red_nes = singleReduce reds
      map_ret = lambdaReturnType map_lam
  guard $ map_ret == lambdaReturnType red_lam -- No mapout arrays.
  guard $ not (null arrs)
  guard $ all primType map_ret
  guard $ all (primType . paramType) (lambdaParams map_lam)
  return (w, arrs, (red_comm, red_lam, red_nes, map_lam))

-- | Statements that we insert directly into every thread-private
-- SegMap.  This is for things that cannot efficiently be computed
-- once in advance in the prelude SegMap, primarily (exclusively?)
-- array slicing operations.
--
-- The first field holds the statements themselves.  The second is a
-- 'ReadPrelude' action that 'addPrivStms' runs, with the slice it is
-- given, before those statements are added.
data PrivStms = PrivStms (Stms Kernels) ReadPrelude

-- | Wrap plain statements as 'PrivStms' whose read action does nothing.
privStms :: Stms Kernels -> PrivStms
privStms stms = PrivStms stms (\_ -> return ())

-- | Emit a 'PrivStms' into the current binder.  The prelude-reading
-- action is run on the given slice first, and only then are the
-- stored statements added, so the statements may refer to whatever
-- the action binds.
addPrivStms :: Slice SubExp -> PrivStms -> Binder Kernels ()
addPrivStms local_slice (PrivStms stms readPrelude) = do
  readPrelude local_slice
  addStms stms

-- | Combining two 'PrivStms' concatenates their statements and
-- sequences their prelude-reading actions, left operand first, both
-- applied to the same slice.
instance Semigroup PrivStms where
  PrivStms stms_x readPrelude_x <> PrivStms stms_y readPrelude_y =
    PrivStms stms_z readPrelude_z
    where stms_z = stms_x <> stms_y
          readPrelude_z slice = readPrelude_x slice >> readPrelude_y slice

-- | The empty 'PrivStms': no statements and a no-op prelude reader.
instance Monoid PrivStms where
  mempty = privStms mempty

-- | A binder action parameterised by a slice.  'mkReadPreludeValues'
-- constructs one that indexes arrays with the slice and binds the
-- results to fixed names.
type ReadPrelude = Slice SubExp -> Binder Kernels ()

-- | Information about a loop that has been tiled inside a kernel, as
-- well as the kinds of changes that we would then like to perform on
-- the kernel.
data Tiling =
  Tiling
  { tilingSegMap :: String -> SegLevel -> ResultManifest
                 -> (PrimExp VName -> Slice SubExp -> Binder Kernels [SubExp])
                 -> Binder Kernels [VName]
    -- Build a SegMap from a callback.  The boolean PrimExp handed to
    -- the callback indicates whether they are in-bounds; the slice is
    -- what callers use to index per-thread values.

  , tilingReadTile :: TileKind -> PrivStms
                   -> SubExp -> [(VName, [Int])]
                   -> Binder Kernels [VName]
    -- Collectively read a tile.  'tileGeneric' passes the tile index
    -- as the SubExp and the input arrays paired with their
    -- permutations.

  , tilingProcessTile :: PrivStms
                      -> Commutativity -> Lambda Kernels -> Lambda Kernels
                      -> [(VName, [Int])] -> [VName]
                      -> Binder Kernels [VName]
    -- Each thread traverses a whole tile and updates its accumulator.
    -- Arguments as used by 'tileGeneric': reduction commutativity,
    -- reduction lambda, map lambda, tiles with permutations, and the
    -- current accumulator names.

  , tilingProcessResidualTile :: PrivStms
                              -> Commutativity -> Lambda Kernels -> Lambda Kernels
                              -> SubExp -> [VName] -> SubExp
                              -> [(VName, [Int])]
                              -> Binder Kernels [VName]
    -- Handle a possible residual tile after the loop over whole
    -- tiles.  Arguments as used by 'tileGeneric': lambdas as above,
    -- the number of whole tiles, the accumulators from the loop, the
    -- input width, and the original arrays with permutations.

  , tilingTileReturns :: VName -> Binder Kernels KernelResult
    -- Produce the kernel result for a named value.

  , tilingSpace :: SegSpace

  , tilingTileShape :: Shape
    -- Shape given to the loop merge parameters in 'tileGeneric'.

  , tilingLevel :: SegLevel
    -- Level from which 'scalarLevel' takes group count and group size.

  , tilingNumWholeTiles :: SubExp
    -- Iteration count of the loop over whole tiles.
  }

-- | A tiling strategy: given the initial level, the thread indices
-- (@gtids@), the kernel dimensions (@kdims@) and a size, construct a
-- 'Tiling'.  'tileGeneric' passes its @w@ argument as the size
-- (presumably the width of the tiled input arrays; confirm against
-- the callers).
type DoTiling gtids kdims =
  SegLevel -> gtids -> kdims -> SubExp -> Binder Kernels Tiling

-- | A non-virtualised 'SegThread' level with the same number of groups
-- and group size as the level of the given tiling.
scalarLevel :: Tiling -> SegLevel
scalarLevel tiling =
  SegThread (segNumGroups lvl) (segGroupSize lvl) SegNoVirt
  where lvl = tilingLevel tiling

-- | Guard a computation behind an in-bounds check.  Builds an @if@
-- on @in_bounds@: the true branch is the body produced by running
-- @m@, the false branch returns one 'eBlank' placeholder per type in
-- @ts@.  The results are bound to fresh names derived from @desc@.
protectOutOfBounds :: String -> PrimExp VName -> [Type] -> Binder Kernels [SubExp]
                   -> Binder Kernels [VName]
protectOutOfBounds desc in_bounds ts m =
  letTupExp desc =<< eIf (toExp in_bounds) (resultBody <$> m) (eBody $ map eBlank ts)

-- | Create the SegMap that runs the postlude for every thread.
--
-- Each name in the pattern @pat@ is bound to the calling thread's
-- element of the corresponding accumulator in @accs'@.  Then the
-- post-statements are run and @poststms_res@ is returned.  When
-- there are no post-statements, the private statements are still
-- emitted (the result may need them) but no bounds check is
-- generated; otherwise everything is wrapped in
-- 'protectOutOfBounds' using the result types @res_ts@.
postludeGeneric :: Tiling -> PrivStms
                -> Pattern Kernels -> [VName]
                -> Stms Kernels -> Result -> [Type]
                -> Binder Kernels [VName]
postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts =
  tilingSegMap tiling "thread_res" (scalarLevel tiling) ResultPrivate $ \in_bounds slice -> do
    -- Read our per-thread result from the tiled loop.
    forM_ (zip (patternNames pat) accs') $ \(us, everyone) ->
      letBindNames_ [us] $ BasicOp $ Index everyone slice

    if poststms == mempty
      then do
        -- The privstms may still be necessary for the result.
        addPrivStms slice privstms
        return poststms_res
      else
        fmap (map Var) $ protectOutOfBounds "postlude" in_bounds res_ts $ do
          addPrivStms slice privstms
          addStms poststms
          return poststms_res

-- | The tiled kernel body produced by 'tileGeneric': given the
-- private statements to insert, it emits the tiled computation into
-- the current binder and returns the names of its results.
type TiledBody = PrivStms -> Binder Kernels [VName]

-- | Tiling driver shared by the concrete tiling strategies.
--
-- Runs @doTiling@ to obtain a 'Tiling' and the statements needed to
-- set it up, and returns those together with a 'TiledBody'.  The
-- tiled body:
--
--   1. initialises per-thread accumulators from the neutral elements;
--
--   2. loops over the whole tiles, collectively reading each tile and
--      letting every thread fold it into its accumulator;
--
--   3. processes a possible residual tile with renamed copies of the
--      lambdas;
--
--   4. runs the postlude via 'postludeGeneric'.
--
-- @form@ is the reduction commutativity, reduction lambda, neutral
-- elements and map lambda; @arrs_and_perms@ are the input arrays
-- with their permutations; @w@ is forwarded to @doTiling@ and to the
-- residual-tile step.
tileGeneric :: DoTiling gtids kdims
            -> SegLevel
            -> [Type]
            -> Pattern Kernels
            -> gtids
            -> kdims
            -> SubExp
            -> (Commutativity, Lambda Kernels, [SubExp], Lambda Kernels)
            -> [(VName, [Int])]
            -> Stms Kernels -> Result
            -> TileM (Stms Kernels, Tiling, TiledBody)
tileGeneric doTiling initial_lvl res_ts pat gtids kdims w form arrs_and_perms poststms poststms_res = do

  (tiling, tiling_stms) <- runBinder $ doTiling initial_lvl gtids kdims w

  return (tiling_stms, tiling, tiledBody tiling)

  where
    (red_comm, red_lam, red_nes, map_lam) = form

    tiledBody :: Tiling -> PrivStms -> Binder Kernels [VName]
    tiledBody tiling privstms = do
      let num_whole_tiles = tilingNumWholeTiles tiling
          tile_shape = tilingTileShape tiling

      -- We don't use a Replicate here, because we want to enforce a
      -- scalar memory space.
      mergeinits <- tilingSegMap tiling "mergeinit" (scalarLevel tiling) ResultPrivate $ \in_bounds slice ->
        -- Constant neutral elements (a common case) do not need protection from OOB.
        if freeIn red_nes == mempty
          then return red_nes
          else fmap (map Var) $ protectOutOfBounds "neutral" in_bounds (lambdaReturnType red_lam) $ do
            addPrivStms slice privstms
            return red_nes

      -- One unique, tile-shaped merge parameter per reduction
      -- parameter, initialised from the corresponding mergeinit.
      merge <- forM (zip (lambdaParams red_lam) mergeinits) $ \(p, mergeinit) -> do
        let merge_t = toDecl (paramType p `arrayOfShape` tile_shape) Unique
        merge_p <- newParam (baseString (paramName p) ++ "_merge") merge_t
        return (merge_p, Var mergeinit)

      tile_id <- newVName "tile_id"
      let loopform = ForLoop tile_id Int32 num_whole_tiles []
      loopbody <- renameBody <=< runBodyBinder $ inScopeOf loopform $
                  localScope (scopeOfFParams $ map fst merge) $ do

        -- Collectively read a tile.
        tile <- tilingReadTile tiling TileFull privstms (Var tile_id) arrs_and_perms

        -- Now each thread performs a traversal of the tile and
        -- updates its accumulator.
        resultBody . map Var <$>
          tilingProcessTile tiling privstms
          red_comm red_lam map_lam
          (zip tile (map snd arrs_and_perms)) (map (paramName . fst) merge)

      accs <- letTupExp "accs" $ DoLoop [] merge loopform loopbody

      -- We possibly have to traverse a residual tile.  The lambdas
      -- were already used in the loop body, so use renamed copies.
      red_lam' <- renameLambda red_lam
      map_lam' <- renameLambda map_lam
      accs' <- tilingProcessResidualTile tiling privstms
               red_comm red_lam' map_lam'
               num_whole_tiles accs w arrs_and_perms

      -- Create a SegMap that takes care of the postlude for every thread.
      postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts

-- | The kind of tile being read by 'tilingReadTile'.  'tileGeneric'
-- uses 'TileFull' inside the loop over whole tiles; 'TilePartial' is
-- presumably for the residual tile (confirm against the tiling
-- implementations).
data TileKind = TilePartial | TileFull

-- | Construct a 'ReadPrelude' that, for each pair of an array from
-- @prestms_live_arrs@ and a name from @prestms_live@, binds the name
-- to the array indexed by the given slice (extended to a full slice
-- for the array's type).
mkReadPreludeValues :: [VName] -> [VName] -> ReadPrelude
mkReadPreludeValues prestms_live_arrs prestms_live slice =
  -- Only the bindings matter, so traverse for effect instead of
  -- collecting and folding a list of units.
  forM_ (zip prestms_live_arrs prestms_live) $ \(arr, v) -> do
    arr_t <- lookupType arr
    letBindNames_ [v] $ BasicOp $ Index arr $ fullSlice arr_t slice

-- | Build a 'TileReturns' kernel result for @arr@.
--
-- When @dims_on_top@ is non-empty, @arr@ is first reshaped so that
-- one unit (size 1) dimension per entry of @dims_on_top@ is prepended
-- to its shape.  The tile dimensions of the result pair the size of
-- each @dims_on_top@ entry with 1, followed by @dims@ unchanged.
tileReturns :: [(VName, SubExp)] -> [(SubExp, SubExp)] -> VName -> Binder Kernels KernelResult
tileReturns dims_on_top dims arr = do
  let unit_dims = replicate (length dims_on_top) (intConst Int32 1)
  arr' <- if null dims_on_top
          then return arr
          else do
            arr_t <- lookupType arr
            let new_shape = unit_dims ++ arrayDims arr_t
            letExp (baseString arr) $ BasicOp $ Reshape (map DimNew new_shape) arr
  let tile_dims = zip (map snd dims_on_top) unit_dims ++ dims
  return $ TileReturns tile_dims arr'

-- | Construct a one-dimensional 'SegMap' at level @lvl@ and bind its
-- results to names derived from @desc@.
--
-- The single dimension of the segment space is indexed by a fresh
-- @ltid@ variable and has the group size of @lvl@ as its extent.  The
-- function @f@ is given that index variable and produces the
-- per-thread results; each of them is returned from the kernel body
-- as a 'Returns' with the given @manifest@.
segMap1D :: String
         -> SegLevel -> ResultManifest
         -> (VName -> Binder Kernels [SubExp])
         -> Binder Kernels [VName]
segMap1D desc lvl manifest f = do
  ltid <- newVName "ltid"
  ltid_flat <- newVName "ltid_flat"
  let space = SegSpace ltid_flat [(ltid, unCount $ segGroupSize lvl)]

  -- Run 'f' in a separate binder so that the statements it produces
  -- end up inside the kernel body rather than in the enclosing scope.
  ((ts, res), stms) <- runBinder $ do
    res <- f ltid
    ts <- mapM subExpType res
    return (ts, res)

  -- The collected body is passed through 'renameBody' before being
  -- embedded in the SegMap.
  Body _ stms' res' <- renameBody $ mkBody stms res

  letTupExp desc $ Op $ SegOp $
    SegMap lvl space ts $ KernelBody () stms' $ map (Returns manifest) res'

-- | Bind @gtid@ to the 32-bit integer expression
-- @gid * group_size + ltid@.
reconstructGtids1D :: Count GroupSize SubExp -> VName -> VName -> VName
                   -> Binder Kernels ()
reconstructGtids1D group_size gtid gid ltid =
  letBindNames_ [gtid] =<<
    toExp (LeafExp gid int32 *
           primExpFromSubExp int32 (unCount group_size) +
           LeafExp ltid int32)

-- | Read one tile from each of the given arrays.
--
-- The read is a thread-level 'segMap1D' (named @"full_tile"@, with
-- 'ResultNoSimplify') in which the thread with local index @ltid@
-- reads element @j = tile_id * tile_size + ltid@ of every array.
-- Before reading, @gtid@ is rebound via 'reconstructGtids1D' and
-- 'addPrivStms' is invoked with the slice @[DimFix ltid]@.
--
-- For 'TilePartial' the read is guarded by @j < w@, where @w@ is
-- obtained with @arraysSize 0@ on the array types; threads that fail
-- the test produce blank values of the row type instead.  For
-- 'TileFull' the read is unconditional.
--
-- Only the array names of @arrs_and_perms@ are used here; the
-- permutations are ignored.
readTile1D :: SubExp -> VName -> VName
           -> Count NumGroups SubExp -> Count GroupSize SubExp
           -> TileKind -> PrivStms
           -> SubExp
           -> [(VName, [Int])]
           -> Binder Kernels [VName]
readTile1D
  tile_size gid gtid num_groups group_size
  kind privstms tile_id arrs_and_perms =

  segMap1D "full_tile" (SegThread num_groups group_size SegNoVirt) ResultNoSimplify $ \ltid -> do
    j <- letSubExp "j" =<<
         toExp (primExpFromSubExp int32 tile_id *
                primExpFromSubExp int32 tile_size +
                LeafExp ltid int32)

    reconstructGtids1D group_size gtid gid ltid
    addPrivStms [DimFix $ Var ltid] privstms

    let arrs = map fst arrs_and_perms
    arr_ts <- mapM lookupType arrs
    let tile_ts = map rowType arr_ts
        w = arraysSize 0 arr_ts

    let readTileElem arr =
          -- No need for fullSlice because we are tiling only prims.
          letExp "tile_elem" $ BasicOp $ Index arr [DimFix j]
    fmap (map Var) $
      case kind of
        TilePartial ->
          letTupExp "pre" =<<
            eIf (toExp $ primExpFromSubExp int32 j .<.
                         primExpFromSubExp int32 w)
                (resultBody <$> mapM (fmap Var . readTileElem) arrs)
                (eBody $ map eBlank tile_ts)
        TileFull ->
          mapM readTileElem arrs

-- | Let every thread fold one tile into its accumulator.
--
-- This is a thread-level 'segMap1D' (named @"acc"@, with
-- 'ResultPrivate').  Each thread rebinds @gtid@ via
-- 'reconstructGtids1D', invokes 'addPrivStms' with the slice
-- @[DimFix ltid]@, and fetches element @ltid@ of every array in
-- @accs@.  If @gtid < kdim@ it then runs a 'Screma' of width
-- @tile_size@ over the tile arrays, using @red_lam@ (with
-- commutativity @red_comm@) as the reduction and @map_lam@ as the map
-- function; otherwise it returns the fetched accumulators unchanged.
--
-- Only the array names of @tiles_and_perm@ are used here; the
-- permutations are ignored.
processTile1D :: VName -> VName -> SubExp -> SubExp
              -> Count NumGroups SubExp -> Count GroupSize SubExp
              -> PrivStms
              -> Commutativity -> Lambda Kernels -> Lambda Kernels
              -> [(VName, [Int])] -> [VName]
              -> Binder Kernels [VName]
processTile1D
  gid gtid kdim tile_size num_groups group_size
  privstms
  red_comm red_lam map_lam tiles_and_perm accs = do

  let tile = map fst tiles_and_perm

  segMap1D "acc" (SegThread num_groups group_size SegNoVirt) ResultPrivate $ \ltid -> do

    reconstructGtids1D group_size gtid gid ltid
    addPrivStms [DimFix $ Var ltid] privstms

    -- We replace the neutral elements with the accumulators (this is
    -- OK because the parallel semantics are not used after this
    -- point).
    thread_accs <- forM accs $ \acc ->
      letSubExp "acc" $ BasicOp $ Index acc [DimFix $ Var ltid]
    let form' = redomapSOAC [Reduce red_comm red_lam thread_accs] map_lam

    fmap (map Var) $
      letTupExp "acc" =<<
        eIf (toExp $ LeafExp gtid int32 .<. primExpFromSubExp int32 kdim)
            (eBody [pure $ Op $ OtherOp $ Screma tile_size form' tile])
            (resultBodyM thread_accs)

-- | Process the input elements left over after all whole tiles.
--
-- The number of residual elements is computed as the signed 32-bit
-- remainder of @w@ by @tile_size@.  If it is zero, the accumulators
-- @accs@ are returned as they are.  Otherwise a partial tile is read
-- with 'readTile1D' (using 'TilePartial' and @num_whole_tiles@ as the
-- tile index), each tile array is cut down to its first
-- @residual_input@ elements, and the result is folded into the
-- accumulators with 'processTile1D', passing @residual_input@ as the
-- tile size.
processResidualTile1D :: VName -> VName -> SubExp -> SubExp
                      -> Count NumGroups SubExp -> Count GroupSize SubExp -> PrivStms
                      -> Commutativity -> Lambda Kernels -> Lambda Kernels
                      -> SubExp -> [VName] -> SubExp -> [(VName, [Int])]
                      -> Binder Kernels [VName]
processResidualTile1D
  gid gtid kdim tile_size num_groups group_size privstms red_comm red_lam map_lam
  num_whole_tiles accs w arrs_and_perms = do
  -- The number of residual elements that are not covered by
  -- the whole tiles.
  residual_input <- letSubExp "residual_input" $
    BasicOp $ BinOp (SRem Int32) w tile_size

  letTupExp "acc_after_residual" =<<
    eIf (toExp $ primExpFromSubExp int32 residual_input .==. 0)
        (resultBodyM $ map Var accs)
        (nonemptyTile residual_input)

  where
    nonemptyTile residual_input = runBodyBinder $ do
      -- Collectively construct a tile.  Threads that are out-of-bounds
      -- provide a blank dummy value.
      full_tile <- readTile1D tile_size gid gtid num_groups group_size
                   TilePartial privstms num_whole_tiles arrs_and_perms

      -- Keep only the leading 'residual_input' elements of each tile
      -- array.  The lambda parameter is deliberately not called
      -- 'tile', so that it does not shadow the binding being defined.
      tile <- forM full_tile $ \whole_tile ->
        letExp "partial_tile" $ BasicOp $ Index whole_tile
          [DimSlice (intConst Int32 0) residual_input (intConst Int32 1)]

      -- Now each thread performs a traversal of the tile and
      -- updates its accumulator.
      resultBody . map Var <$> processTile1D
        gid gtid kdim residual_input num_groups group_size privstms
        red_comm red_lam map_lam (zip tile $ repeat [0]) accs

tiling1d :: [(VName,SubExp)] -> DoTiling VName SubExp
tiling1d :: [(VName, SubExp)] -> DoTiling VName SubExp
tiling1d [(VName, SubExp)]
dims_on_top SegLevel
initial_lvl VName
gtid SubExp
kdim SubExp
w = do
  VName
gid <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"gid"
  VName
gid_flat <- String -> BinderT Kernels (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"gid_flat"

  (SegLevel
lvl, SegSpace
space) <-
    if [(VName, SubExp)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(VName, SubExp)]
dims_on_top
    then (SegLevel, SegSpace)
-> BinderT Kernels (State VNameSource) (SegLevel, SegSpace)
forall (m :: * -> *) a. Monad m => a -> m a
return (Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegGroup (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
initial_lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
initial_lvl) (SegVirt -> SegLevel) -> SegVirt -> SegLevel
forall a b. (a -> b) -> a -> b
$ SegLevel -> SegVirt
segVirt SegLevel
initial_lvl,
                 VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat [(VName
gid, Count NumGroups SubExp -> SubExp
forall u e. Count u e -> e
unCount (Count NumGroups SubExp -> SubExp)
-> Count NumGroups SubExp -> SubExp
forall a b. (a -> b) -> a -> b
$ SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
initial_lvl)])
    else do
      SubExp
group_size <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"computed_group_size" (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) SubExp)
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$
                    BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> BinOp
SMin IntType
Int32) (Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
initial_lvl)) SubExp
kdim

      -- How many groups we need to exhaust the innermost dimension.
      SubExp
ldim <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"ldim" (ExpT Kernels -> BinderT Kernels (State VNameSource) SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
              IntType
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
IntType -> m (Exp (Lore m)) -> m (Exp (Lore m)) -> m (Exp (Lore m))
eDivRoundingUp IntType
Int32 (SubExp
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *). MonadBinder m => SubExp -> m (Exp (Lore m))
eSubExp SubExp
kdim) (SubExp
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *). MonadBinder m => SubExp -> m (Exp (Lore m))
eSubExp SubExp
group_size)

      SubExp
num_groups <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"computed_num_groups" (ExpT Kernels -> BinderT Kernels (State VNameSource) SubExp)
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
                    BinOp
-> SubExp
-> Result
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
BinOp -> SubExp -> Result -> m (Exp (Lore m))
foldBinOp (IntType -> Overflow -> BinOp
Mul IntType
Int32 Overflow
OverflowUndef) SubExp
ldim (((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(VName, SubExp)]
dims_on_top)

      (SegLevel, SegSpace)
-> BinderT Kernels (State VNameSource) (SegLevel, SegSpace)
forall (m :: * -> *) a. Monad m => a -> m a
return (Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegGroup (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Count SubExp
num_groups) (SubExp -> Count GroupSize SubExp
forall u e. e -> Count u e
Count SubExp
group_size) SegVirt
SegNoVirt,
              VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat ([(VName, SubExp)] -> SegSpace) -> [(VName, SubExp)] -> SegSpace
forall a b. (a -> b) -> a -> b
$ [(VName, SubExp)]
dims_on_top [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
gid, SubExp
ldim)])
  let tile_size :: SubExp
tile_size = Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount (Count GroupSize SubExp -> SubExp)
-> Count GroupSize SubExp -> SubExp
forall a b. (a -> b) -> a -> b
$ SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl

  -- Number of whole tiles that fit in the input.
  SubExp
num_whole_tiles <- String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"num_whole_tiles" (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) SubExp)
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> BinOp
SQuot IntType
Int32) SubExp
w SubExp
tile_size
  Tiling -> Binder Kernels Tiling
forall (m :: * -> *) a. Monad m => a -> m a
return Tiling :: (String
 -> SegLevel
 -> ResultManifest
 -> (PrimExp VName
     -> Slice SubExp -> BinderT Kernels (State VNameSource) Result)
 -> BinderT Kernels (State VNameSource) [VName])
-> (TileKind
    -> PrivStms
    -> SubExp
    -> [(VName, [Int])]
    -> BinderT Kernels (State VNameSource) [VName])
-> (PrivStms
    -> Commutativity
    -> Lambda Kernels
    -> Lambda Kernels
    -> [(VName, [Int])]
    -> [VName]
    -> BinderT Kernels (State VNameSource) [VName])
-> (PrivStms
    -> Commutativity
    -> Lambda Kernels
    -> Lambda Kernels
    -> SubExp
    -> [VName]
    -> SubExp
    -> [(VName, [Int])]
    -> BinderT Kernels (State VNameSource) [VName])
-> (VName -> BinderT Kernels (State VNameSource) KernelResult)
-> SegSpace
-> Shape
-> SegLevel
-> SubExp
-> Tiling
Tiling
    { tilingSegMap :: String
-> SegLevel
-> ResultManifest
-> (PrimExp VName
    -> Slice SubExp -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
tilingSegMap = \String
desc SegLevel
lvl' ResultManifest
manifest PrimExp VName
-> Slice SubExp -> BinderT Kernels (State VNameSource) Result
f -> String
-> SegLevel
-> ResultManifest
-> (VName -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
segMap1D String
desc SegLevel
lvl' ResultManifest
manifest ((VName -> BinderT Kernels (State VNameSource) Result)
 -> BinderT Kernels (State VNameSource) [VName])
-> (VName -> BinderT Kernels (State VNameSource) Result)
-> BinderT Kernels (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ \VName
ltid -> do
        [VName]
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames_ [VName
gtid] (ExpT Kernels -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) (ExpT Kernels)
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
          PrimExp VName
-> BinderT
     Kernels
     (State VNameSource)
     (Exp (Lore (BinderT Kernels (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (VName -> PrimType -> PrimExp VName
forall v. v -> PrimType -> PrimExp v
LeafExp VName
gid PrimType
int32 PrimExp VName -> PrimExp VName -> PrimExp VName
forall a. Num a => a -> a -> a
* PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
int32 SubExp
tile_size PrimExp VName -> PrimExp VName -> PrimExp VName
forall a. Num a => a -> a -> a
+
                 VName -> PrimType -> PrimExp VName
forall v. v -> PrimType -> PrimExp v
LeafExp VName
ltid PrimType
int32)
        PrimExp VName
-> Slice SubExp -> BinderT Kernels (State VNameSource) Result
f (VName -> PrimType -> PrimExp VName
forall v. v -> PrimType -> PrimExp v
LeafExp VName
gtid PrimType
int32 PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
int32 SubExp
kdim)
          [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
ltid]

    , tilingReadTile :: TileKind
-> PrivStms
-> SubExp
-> [(VName, [Int])]
-> BinderT Kernels (State VNameSource) [VName]
tilingReadTile =
        SubExp
-> VName
-> VName
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> TileKind
-> PrivStms
-> SubExp
-> [(VName, [Int])]
-> BinderT Kernels (State VNameSource) [VName]
readTile1D SubExp
tile_size VName
gid VName
gtid (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl)

    , tilingProcessTile :: PrivStms
-> Commutativity
-> Lambda Kernels
-> Lambda Kernels
-> [(VName, [Int])]
-> [VName]
-> BinderT Kernels (State VNameSource) [VName]
tilingProcessTile =
        VName
-> VName
-> SubExp
-> SubExp
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> PrivStms
-> Commutativity
-> Lambda Kernels
-> Lambda Kernels
-> [(VName, [Int])]
-> [VName]
-> BinderT Kernels (State VNameSource) [VName]
processTile1D VName
gid VName
gtid SubExp
kdim SubExp
tile_size (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl)

    , tilingProcessResidualTile :: PrivStms
-> Commutativity
-> Lambda Kernels
-> Lambda Kernels
-> SubExp
-> [VName]
-> SubExp
-> [(VName, [Int])]
-> BinderT Kernels (State VNameSource) [VName]
tilingProcessResidualTile =
        VName
-> VName
-> SubExp
-> SubExp
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> PrivStms
-> Commutativity
-> Lambda Kernels
-> Lambda Kernels
-> SubExp
-> [VName]
-> SubExp
-> [(VName, [Int])]
-> BinderT Kernels (State VNameSource) [VName]
processResidualTile1D VName
gid VName
gtid SubExp
kdim SubExp
tile_size (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl)

    , tilingTileReturns :: VName -> BinderT Kernels (State VNameSource) KernelResult
tilingTileReturns = [(VName, SubExp)]
-> [(SubExp, SubExp)]
-> VName
-> BinderT Kernels (State VNameSource) KernelResult
tileReturns [(VName, SubExp)]
dims_on_top [(SubExp
kdim, SubExp
tile_size)]

    , tilingTileShape :: Shape
tilingTileShape = Result -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
tile_size]
    , tilingNumWholeTiles :: SubExp
tilingNumWholeTiles = SubExp
num_whole_tiles
    , tilingLevel :: SegLevel
tilingLevel = SegLevel
lvl
    , tilingSpace :: SegSpace
tilingSpace = SegSpace
space
    }

-- | Decide whether @arr@ depends on exactly one of the two innermost
-- dimensions listed in @dims@.
--
-- The result is @Just [0,1]@ when @arr@ is variant to the
-- second-to-last dimension but not to the last one, and @Just [1,0]@
-- in the mirrored situation.  It is @Nothing@ when @dims@ has fewer
-- than two elements, when either of the two innermost dimensions
-- occurs in @branch_variant@, or when @arr@ is variant to both or to
-- neither of them.  An array with no entry in @variance@ is treated
-- as variant to nothing.
invariantToOneOfTwoInnerDims :: Names -> M.Map VName Names -> [VName] -> VName
                             -> Maybe [Int]
invariantToOneOfTwoInnerDims branch_variant variance dims arr =
  case reverse dims of
    inner : outer : _
      | in_branch -> Nothing
      | on_outer && not on_inner -> Just [0, 1]
      | on_inner && not on_outer -> Just [1, 0]
      | otherwise -> Nothing
      where deps = M.findWithDefault mempty arr variance
            -- Tiling is refused as soon as one of the two inner
            -- dimensions is in the branch-variant set.
            in_branch = inner `nameIn` branch_variant
                        || outer `nameIn` branch_variant
            on_outer = outer `nameIn` deps
            on_inner = inner `nameIn` deps
    _ -> Nothing

-- | Build a two-dimensional 'SegMap' at level @lvl@ and bind its
-- results to names derived from @desc@.
--
-- The segment space has dimensions @dim_x@ and @dim_y@, indexed by
-- freshly generated thread IDs.  The body is produced by applying
-- @f@ to those two IDs; every value @f@ returns becomes a 'Returns'
-- kernel result tagged with @manifest@.
segMap2D :: String
         -> SegLevel -> ResultManifest -> (SubExp, SubExp)
         -> ((VName, VName) -> Binder Kernels [SubExp])
         -> Binder Kernels [VName]
segMap2D desc lvl manifest (dim_x, dim_y) f = do
  tid_x <- newVName "ltid_x"
  tid_y <- newVName "ltid_y"
  tid_flat <- newVName "ltid_flat"

  -- Run the body builder in a nested binder, so that the statements
  -- it emits are collected for the kernel body instead of being added
  -- to the enclosing scope.
  ((ret_types, raw_res), raw_stms) <- runBinder $ do
    rs <- f (tid_x, tid_y)
    tys <- mapM subExpType rs
    return (tys, rs)

  -- The collected body is passed through 'renameBody' before use
  -- (presumably to keep names unique if the same builder is
  -- instantiated more than once).
  Body _ body_stms body_res <- renameBody (mkBody raw_stms raw_res)

  let space = SegSpace tid_flat [(tid_x, dim_x), (tid_y, dim_y)]
      kbody = KernelBody () body_stms (map (Returns manifest) body_res)
  letTupExp desc $ Op $ SegOp $ SegMap lvl space ret_types kbody

-- Reconstruct the original gtids from group and local IDs.
--
-- For each of the two dimensions, the given gtid name is bound to
-- @gid * tile_size + ltid@ (computed as 32-bit integers).
reconstructGtids2D :: SubExp -> (VName, VName) -> (VName, VName) -> (VName, VName)
                   -> Binder Kernels ()
reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y) = do
  bindGtid gtid_x gid_x ltid_x
  bindGtid gtid_y gid_y ltid_y
  where
    tile_size' = primExpFromSubExp int32 tile_size

    -- Bind one gtid from the group ID and local ID of its dimension.
    bindGtid :: VName -> VName -> VName -> Binder Kernels ()
    bindGtid gtid gid ltid =
      letBindNames_ [gtid] =<<
        toExp (LeafExp gid int32 * tile_size' + LeafExp ltid int32)

-- | Construct the thread-level 'SegMap' (described as @"full_tile"@)
-- in which every thread of a @tile_size@ by @tile_size@ space reads
-- one element from each array in @arrs_and_perms@.
--
-- Inside the map, @i@ and @j@ are @tile_id * tile_size + ltid_x@ and
-- @tile_id * tile_size + ltid_y@.  For each array, the index actually
-- used is the last element of @[i, j]@ rearranged by that array's
-- permutation.  The gtids are rebound via 'reconstructGtids2D' and
-- the private statements are added via 'addPrivStms' before any read.
--
-- With 'TileFull' the reads are unconditional.  With 'TilePartial'
-- each read is wrapped in a branch that checks the index against the
-- outer size of the arrays and the matching gtid against its kernel
-- dimension; when the check fails the branch yields 'eBlank' of the
-- row type instead.
readTile2D :: (SubExp, SubExp) -> (VName, VName) -> (VName, VName) -> SubExp
           -> Count NumGroups SubExp -> Count GroupSize SubExp
           -> TileKind -> PrivStms -> SubExp
           -> [(VName, [Int])]
           -> Binder Kernels [VName]
readTile2D (kdim_x, kdim_y) (gtid_x, gtid_y) (gid_x, gid_y) tile_size num_groups group_size kind privstms tile_id arrs_and_perms =
  segMap2D "full_tile" thread_lvl ResultNoSimplify (tile_size, tile_size) $
  \(ltid_x, ltid_y) -> do
    i <- letSubExp "i" =<< tileOffset ltid_x
    j <- letSubExp "j" =<< tileOffset ltid_y

    reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y)
    addPrivStms (map (DimFix . Var) [ltid_x, ltid_y]) privstms

    let (arrs, perms) = unzip arrs_and_perms
    arr_ts <- mapM lookupType arrs
    let w = arraysSize 0 arr_ts

        -- No need for fullSlice because we are tiling only prims.
        readUnchecked arr perm =
          letExp "tile_elem" $ BasicOp $ Index arr [DimFix $ choose perm [i, j]]

        -- Guarded read: the same permutation that selects the index
        -- also selects which gtid bound has to hold.
        readChecked tile_t arr perm = do
          let idx = choose perm [i, j]
              othercheck = choose perm
                [ LeafExp gtid_y int32 .<. primExpFromSubExp int32 kdim_y
                , LeafExp gtid_x int32 .<. primExpFromSubExp int32 kdim_x
                ]
          eIf (toExp $
               primExpFromSubExp int32 idx .<. primExpFromSubExp int32 w .&&. othercheck)
            (eBody [return $ BasicOp $ Index arr [DimFix idx]])
            (eBody [eBlank tile_t])

    tile_elems <- case kind of
      TilePartial ->
        forM (zip3 (map rowType arr_ts) arrs perms) $ \(tile_t, arr, perm) ->
          letExp "pre" =<< readChecked tile_t arr perm
      TileFull ->
        zipWithM readUnchecked arrs perms
    return $ map Var tile_elems
  where
    thread_lvl = SegThread num_groups group_size SegNoVirtFull

    -- Position of a local thread ID within the tile numbered tile_id.
    tileOffset ltid =
      toExp (primExpFromSubExp int32 tile_id *
             primExpFromSubExp int32 tile_size +
             LeafExp ltid int32)

    -- Apply the permutation and keep the last element.
    choose :: [Int] -> [a] -> a
    choose perm = last . rearrangeShape perm

-- | Generate the code by which every thread of a two-dimensional
-- tile folds one round of tiles into its own accumulators.
--
-- A thread-level 'segMap2D' over @tile_size@ x @tile_size@ is
-- emitted.  Inside it, each thread:
--
--   1. calls 'reconstructGtids2D' and 'addPrivStms' (the latter with
--      the slice @[ltid_x, ltid_y]@),
--
--   2. reads its accumulators by indexing every array in @accs@ at
--      @[ltid_x, ltid_y]@,
--
--   3. indexes every tile in @tiles_and_perms@ along the dimension
--      given by the first element of its permutation, and
--
--   4. if @gtid_x < kdim_x && gtid_y < kdim_y@, runs a 'Screma' built
--      with 'redomapSOAC' from @red_comm@, @red_lam@ and @map_lam@
--      over those tile slices; otherwise it returns the accumulators
--      it read in step 2 unchanged.
--
-- The returned names are the ones produced by 'segMap2D'.
processTile2D :: (VName, VName) -> (VName, VName) -> (SubExp, SubExp) -> SubExp
              -> Count NumGroups SubExp -> Count GroupSize SubExp
              -> PrivStms
              -> Commutativity -> Lambda Kernels -> Lambda Kernels
              -> [(VName,[Int])] -> [VName]
              -> Binder Kernels [VName]
processTile2D
  (gid_x, gid_y) (gtid_x, gtid_y) (kdim_x, kdim_y) tile_size num_groups group_size
  privstms red_comm red_lam map_lam tiles_and_perms accs = do

  -- Might be truncated in case of a partial tile.
  actual_tile_size <- arraysSize 0 <$> mapM (lookupType . fst) tiles_and_perms

  segMap2D "acc" (SegThread num_groups group_size SegNoVirtFull)
    ResultPrivate (tile_size, tile_size) $ \(ltid_x, ltid_y) -> do
    reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y)

    addPrivStms [DimFix $ Var ltid_x, DimFix $ Var ltid_y] privstms

    -- We replace the neutral elements with the accumulators (this is
    -- OK because the parallel semantics are not used after this
    -- point).
    thread_accs <- forM accs $ \acc ->
      letSubExp "acc" $ BasicOp $ Index acc [DimFix $ Var ltid_x, DimFix $ Var ltid_y]
    let form' = redomapSOAC [Reduce red_comm red_lam thread_accs] map_lam

    -- NOTE(review): both uses of 'head' below are partial; they assume
    -- every @perm@ is non-empty -- confirm against the callers.
    tiles' <- forM tiles_and_perms $ \(tile, perm) -> do
      tile_t <- lookupType tile
      letExp "tile" $ BasicOp $ Index tile $ sliceAt tile_t (head perm)
        [DimFix $ Var $ head $ rearrangeShape perm [ltid_x, ltid_y]]

    -- Only threads whose global ids are in bounds run the redomap;
    -- the others pass their accumulators through untouched.
    fmap (map Var) $
      letTupExp "acc" =<< eIf (toExp $
                               LeafExp gtid_x int32 .<. primExpFromSubExp int32 kdim_x .&&.
                               LeafExp gtid_y int32 .<. primExpFromSubExp int32 kdim_y)
      (eBody [pure $ Op $ OtherOp $ Screma actual_tile_size form' tiles'])
      (resultBodyM thread_accs)

-- | Process the input elements that remain after all whole tiles
-- have been handled in the two-dimensional case.
--
-- The residual size is computed as @w@ 'SRem' @tile_size@.  When it
-- is zero, the incoming accumulators @accs@ are returned unchanged.
-- Otherwise a partial tile is read with 'readTile2D' (in
-- 'TilePartial' mode, passing @num_whole_tiles@), each resulting
-- tile is sliced with @DimSlice 0 residual_input 1@ in both
-- dimensions, and the slices are handed to 'processTile2D' together
-- with the permutations from @arrs_and_perms@.
--
-- The non-empty branch is built with 'runBodyBinder' and then passed
-- through 'renameBody'.  The returned names are bound under the
-- description @"acc_after_residual"@.
processResidualTile2D :: (VName, VName) -> (VName, VName) -> (SubExp, SubExp) -> SubExp
                      -> Count NumGroups SubExp -> Count GroupSize SubExp -> PrivStms
                      -> Commutativity -> Lambda Kernels -> Lambda Kernels
                      -> SubExp -> [VName] -> SubExp -> [(VName, [Int])]
                      -> Binder Kernels [VName]
processResidualTile2D
  gids gtids kdims tile_size num_groups group_size privstms red_comm red_lam map_lam
  num_whole_tiles accs w arrs_and_perms = do
  -- The number of residual elements that are not covered by
  -- the whole tiles.
  residual_input <- letSubExp "residual_input" $
    BasicOp $ BinOp (SRem Int32) w tile_size

  letTupExp "acc_after_residual" =<<
    eIf (toExp $ primExpFromSubExp int32 residual_input .==. 0)
    (resultBodyM $ map Var accs)
    (nonemptyTile residual_input)

  where
    -- Body used when there is at least one residual element.
    nonemptyTile residual_input = renameBody <=< runBodyBinder $ do
      -- Collectively construct a tile.  Threads that are out-of-bounds
      -- provide a blank dummy value.
      full_tiles <- readTile2D kdims gtids gids tile_size num_groups group_size
                    TilePartial privstms num_whole_tiles arrs_and_perms

      -- Restrict every tile to its residual_input x residual_input
      -- corner.
      partial_tiles <- forM full_tiles $ \full_tile ->
        letExp "partial_tile" $ BasicOp $ Index full_tile
        [DimSlice (intConst Int32 0) residual_input (intConst Int32 1),
         DimSlice (intConst Int32 0) residual_input (intConst Int32 1)]

      -- Now each thread performs a traversal of the tile and
      -- updates its accumulator.
      resultBody . map Var <$>
        processTile2D gids gtids kdims tile_size num_groups group_size
        privstms red_comm red_lam map_lam
        (zip partial_tiles (map snd arrs_and_perms)) accs

-- | Construct the 'Tiling' used for two-dimensional tiling.
--
-- The following are bound or created here:
--
--   * fresh names @gid_x@, @gid_y@ and @gid_flat@;
--
--   * @tile_size@, obtained with 'GetSize' under a freshly named key
--     of class 'SizeTile';
--
--   * @group_size@, which is @tile_size * tile_size@;
--
--   * the number of groups along each dimension, which is the
--     corresponding kernel dimension divided by @tile_size@, rounded
--     up, and the total number of groups, which is their product
--     multiplied by the sizes in @dims_on_top@;
--
--   * @num_whole_tiles@, which is @w@ 'SQuot' @tile_size@.
--
-- The resulting 'Tiling' uses a 'SegGroup' level and a 'SegSpace'
-- consisting of @dims_on_top@ followed by the two group dimensions.
-- Its operations are the 2D helpers ('segMap2D', 'readTile2D',
-- 'processTile2D', 'processResidualTile2D', 'tileReturns') partially
-- applied to the values above.  The initial level argument is
-- ignored.
tiling2d :: [(VName,SubExp)] -> DoTiling (VName, VName) (SubExp, SubExp)
tiling2d dims_on_top _initial_lvl (gtid_x, gtid_y) (kdim_x, kdim_y) w = do
  gid_x <- newVName "gid_x"
  gid_y <- newVName "gid_y"

  tile_size_key <- nameFromString . pretty <$> newVName "tile_size"
  tile_size <- letSubExp "tile_size" $ Op $ SizeOp $ GetSize tile_size_key SizeTile
  group_size <- letSubExp "group_size" $
                BasicOp $ BinOp (Mul Int32 OverflowUndef) tile_size tile_size

  num_groups_x <- letSubExp "num_groups_x" =<<
                  eDivRoundingUp Int32 (eSubExp kdim_x) (eSubExp tile_size)
  num_groups_y <- letSubExp "num_groups_y" =<<
                  eDivRoundingUp Int32 (eSubExp kdim_y) (eSubExp tile_size)

  num_groups <- letSubExp "num_groups_top" =<<
                foldBinOp (Mul Int32 OverflowUndef) num_groups_x
                (num_groups_y : map snd dims_on_top)

  gid_flat <- newVName "gid_flat"
  let lvl = SegGroup (Count num_groups) (Count group_size) SegNoVirtFull
      space = SegSpace gid_flat $
              dims_on_top ++ [(gid_x, num_groups_x), (gid_y, num_groups_y)]

  -- Number of whole tiles that fit in the input.
  num_whole_tiles <- letSubExp "num_whole_tiles" $
    BasicOp $ BinOp (SQuot Int32) w tile_size
  return Tiling
    { tilingSegMap = \desc lvl' manifest f ->
        segMap2D desc lvl' manifest (tile_size, tile_size) $ \(ltid_x, ltid_y) -> do
          reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y)
          -- The callback receives the in-bounds condition and the
          -- slice that identifies this thread within the tile.
          f (LeafExp gtid_x int32 .<. primExpFromSubExp int32 kdim_x .&&.
             LeafExp gtid_y int32 .<. primExpFromSubExp int32 kdim_y)
            [DimFix $ Var ltid_x, DimFix $ Var ltid_y]

    , tilingReadTile =
        readTile2D (kdim_x, kdim_y) (gtid_x, gtid_y) (gid_x, gid_y) tile_size
        (segNumGroups lvl) (segGroupSize lvl)
    , tilingProcessTile =
        processTile2D (gid_x, gid_y) (gtid_x, gtid_y) (kdim_x, kdim_y) tile_size
        (segNumGroups lvl) (segGroupSize lvl)
    , tilingProcessResidualTile =
        processResidualTile2D (gid_x, gid_y) (gtid_x, gtid_y) (kdim_x, kdim_y) tile_size
        (segNumGroups lvl) (segGroupSize lvl)

    , tilingTileReturns =
        tileReturns dims_on_top [(kdim_x, tile_size), (kdim_y, tile_size)]

    , tilingTileShape = Shape [tile_size, tile_size]
    , tilingNumWholeTiles = num_whole_tiles
    , tilingLevel = lvl
    , tilingSpace = space
    }

-- | The variance table keeps a mapping from a variable name
-- (something produced by a 'Stm') to the kernel thread indices
-- that name depends on.  If a variable is not present in this table,
-- that means it is bound outside the kernel (and so can be considered
-- invariant to all dimensions).  Accordingly, 'varianceInStm' looks
-- up absent names with 'M.findWithDefault' and 'mempty' as the
-- default.
type VarianceTable = M.Map VName Names

-- | Extend a variance table with every statement of a statement
-- sequence, visiting the statements from left to right and feeding
-- the table produced for one statement into the next via
-- 'varianceInStm'.
--
-- The fold is the strict 'foldl'' rather than the lazy 'foldl': the
-- accumulator is a strict map that is needed in full anyway, so a lazy
-- fold would only build a chain of suspended 'varianceInStm'
-- applications as long as the statement sequence.
varianceInStms :: VarianceTable -> Stms Kernels -> VarianceTable
varianceInStms = foldl' varianceInStm

-- | Record, in the given variance table, an entry for every name bound
-- by the pattern of a single statement.
--
-- All names bound by the statement receive the same set: for each
-- free variable of the statement, that variable itself plus whatever
-- the /incoming/ table already records for it (nothing extra if the
-- variable has no entry).
varianceInStm :: VarianceTable -> Stm Kernels -> VarianceTable
varianceInStm variance bnd =
  foldl' record variance (patternNames (stmPattern bnd))
  where
    -- Bind one pattern name to the statement-wide dependency set.
    record table name = M.insert name stmDeps table

    -- A free variable contributes itself and its recorded variance,
    -- always looked up in the table as it was before this statement.
    depsOf name = oneName name <> M.findWithDefault mempty name variance

    -- The dependency set shared by every name the statement binds.
    stmDeps = mconcat [depsOf name | name <- namesToList (freeIn bnd)]