{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
-- | Kernel extraction.
--
-- In the following, I will use the term "width" to denote the amount
-- of immediate parallelism in a map - that is, the outer size of the
-- array(s) being used as input.
--
-- = Basic Idea
--
-- If we have:
--
-- @
--   map
--     map(f)
--     bnds_a...
--     map(g)
-- @
--
-- Then we want to distribute to:
--
-- @
--   map
--     map(f)
--   map
--     bnds_a
--   map
--     map(g)
-- @
--
-- But for now only if
--
--  (0) it can be done without creating irregular arrays.
--      Specifically, the size of the arrays created by @map(f)@, by
--      @map(g)@ and whatever is created by @bnds_a@ that is also used
--      in @map(g)@, must be invariant to the outermost loop.
--
--  (1) the maps are _balanced_.  That is, the functions @f@ and @g@
--      must do the same amount of work for every iteration.
--
-- The advantage is that the map-nests containing @map(f)@ and
-- @map(g)@ can now be trivially flattened at no cost, thus exposing
-- more parallelism.  Note that the @bnds_a@ map constitutes array
-- expansion, which requires additional storage.
--
-- = Distributing Sequential Loops
--
-- As a starting point, sequential loops are treated like scalar
-- expressions.  That is, not distributed.  However, sometimes it can
-- be worthwhile to distribute if they contain a map:
--
-- @
--   map
--     loop
--       map
--     map
-- @
--
-- If we distribute the loop and interchange the outer map into the
-- loop, we get this:
--
-- @
--   loop
--     map
--       map
--   map
--     map
-- @
--
-- Now more parallelism may be available.
--
-- = Unbalanced Maps
--
-- Unbalanced maps will as a rule be sequentialised, but sometimes,
-- there is another way.  Assume we find this:
--
-- @
--   map
--     map(f)
--       map(g)
--     map
-- @
--
-- Presume that @map(f)@ is unbalanced.  By the simple rule above, we
-- would then fully sequentialise it, resulting in this:
--
-- @
--   map
--     loop
--   map
--     map
-- @
--
-- == Balancing by Loop Interchange
--
-- The above is not ideal, as we cannot flatten the @map-loop@ nest,
-- and we are thus limited in the amount of parallelism available.
--
-- But assume now that the width of @map(g)@ is invariant to the outer
-- loop.  Then if possible, we can interchange @map(f)@ and @map(g)@,
-- sequentialise @map(f)@ and distribute, interchanging the outer
-- parallel loop into the sequential loop:
--
-- @
--   loop(f)
--     map
--       map(g)
--   map
--     map
-- @
--
-- After flattening the two nests we can obtain more parallelism.
--
-- When distributing a map, we also need to distribute everything that
-- the map depends on - possibly as its own map.  When distributing a
-- set of scalar bindings, we will need to know which of the binding
-- results are used afterwards.  Hence, we will need to compute usage
-- information.
--
-- = Redomap
--
-- Redomap can be handled much like map.  Distributed loops are
-- distributed as maps, with the parameters corresponding to the
-- neutral elements added to their bodies.  The remaining loop will
-- remain a redomap.  Example:
--
-- @
-- redomap(op,
--         fn (v) =>
--           map(f)
--           map(g),
--         e,a)
-- @
--
-- distributes to
--
-- @
-- let b = map(fn v =>
--               let acc = e
--               map(f),
--               a)
-- redomap(op,
--         fn (v,dist) =>
--           map(g),
--         e,a,b)
-- @
--
-- Note that there may be further kernel extraction opportunities
-- inside the @map(f)@.  The downside of this approach is that the
-- intermediate array (@b@ above) must be written to main memory.  An
-- often better approach is to just turn the entire @redomap@ into a
-- single kernel.
--
module Futhark.Pass.ExtractKernels (extractKernels) where

import Control.Monad.Identity
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Data.Maybe

import Prelude hiding (log)

import Futhark.Representation.SOACS
import Futhark.Representation.SOACS.Simplify (simplifyStms)
import qualified Futhark.Representation.Kernels as Out
import Futhark.Representation.Kernels.Kernel
import Futhark.MonadFreshNames
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT
import Futhark.Transform.Rename
import Futhark.Pass
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.Intragroup
import Futhark.Util
import Futhark.Util.Log

-- | Transform a program using SOACs to a program using explicit
-- kernels, using the kernel extraction transformation.
--
-- The actual work is done by 'transformProg'; this value merely
-- packages it as a 'Pass' with a name and a description.
extractKernels :: Pass SOACS Out.Kernels
extractKernels =
  Pass { passName = "extract kernels"
       , passDescription = "Perform kernel extraction"
       , passFunction = transformProg
       }

-- | Transform an entire program.
--
-- The program-level constants are transformed first (with an empty
-- kernel path).  Every function definition is then transformed with
-- the scope of the /transformed/ constants available, so that function
-- bodies can refer to them.
transformProg :: Prog SOACS -> PassM (Prog Out.Kernels)
transformProg (Prog consts funs) = do
  consts' <- runDistribM $ transformStms mempty $ stmsToList consts
  funs' <- mapM (transformFunDef $ scopeOf consts') funs
  return $ Prog consts' funs'

-- In order to generate more stable threshold names, we keep track of
-- the numbers used for thresholds separately from the ordinary name
-- source.
data State = State { stateNameSource :: VNameSource
                     -- ^ The ordinary source of fresh variable names.
                   , stateThresholdCounter :: Int
                     -- ^ The number to use for the next threshold name;
                     -- incremented by 'cmpSizeLe'.
                   }

-- | The monad in which kernel extraction takes place.  It is an 'RWS'
-- monad where
--
--  * the environment is the scope (in the @Kernels@ representation)
--    of the names currently bound,
--
--  * the writer output is the compiler 'Log', and
--
--  * the state is a 'State', i.e. the name source plus the threshold
--    counter.
--
-- All instances are obtained from the underlying 'RWS' via newtype
-- deriving, except 'MonadFreshNames', which is written by hand because
-- the name source is only one component of the state.
newtype DistribM a = DistribM (RWS (Scope Out.Kernels) Log State a)
                   deriving (Functor, Applicative, Monad,
                             HasScope Out.Kernels, LocalScope Out.Kernels,
                             MonadState State,
                             MonadLogger)

-- | Fresh names are drawn from the 'stateNameSource' component of the
-- state; the threshold counter is left untouched by these operations.
instance MonadFreshNames DistribM where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s { stateNameSource = src }

-- | Run a 'DistribM' action inside any monad that can supply fresh
-- names and accept log messages.
--
-- The action starts with an empty scope and a threshold counter of
-- zero.  The name source is borrowed from the surrounding monad and
-- handed back (in its updated form) when the action finishes; the log
-- produced by the action is forwarded with 'addLog'.
runDistribM :: (MonadLogger m, MonadFreshNames m) =>
               DistribM a -> m a
runDistribM (DistribM m) = do
  (x, msgs) <- modifyNameSource $ \src ->
    let (x, s, msgs) = runRWS m mempty (State src 0)
    in ((x, msgs), stateNameSource s)
  addLog msgs
  return x

-- | Transform a single function definition.
--
-- The first argument is the scope that should be visible in the
-- function body in addition to the function's own parameters (the
-- caller in this module passes the scope of the program constants).
-- Only the body is rewritten; entry point information, name, return
-- type and parameters are carried over unchanged.
--
-- Note that each function is run through its own 'runDistribM', so the
-- threshold counter starts from zero for every function.
transformFunDef :: (MonadFreshNames m, MonadLogger m) =>
                   Scope Out.Kernels -> FunDef SOACS
                -> m (Out.FunDef Out.Kernels)
transformFunDef scope (FunDef entry name rettype params body) = runDistribM $ do
  body' <- localScope (scope <> scopeOfFParams params) $
           transformBody mempty body
  return $ FunDef entry name rettype params body'

-- | Transform a body by transforming its statements (under the given
-- kernel path).  The body result is reused as-is.
transformBody :: KernelPath -> Body -> DistribM (Out.Body Out.Kernels)
transformBody path body = do
  bnds <- transformStms path $ stmsToList $ bodyStms body
  return $ mkBody bnds $ bodyResult body

-- | Transform a sequence of statements, front to back.
--
-- For each statement we first ask 'sequentialisedUnbalancedStm'
-- whether it should be replaced:
--
--  * If not ('Nothing'), the statement is handed to 'transformStm',
--    and the remaining statements are transformed with the names bound
--    by the result in scope.
--
--  * If so (@'Just' stms@), the replacement statements are put back in
--    front of the remaining statements and processing restarts from
--    there, so the replacements are themselves subject to
--    transformation.
transformStms :: KernelPath -> [Stm] -> DistribM KernelsStms
transformStms _ [] =
  return mempty
transformStms path (bnd:bnds) =
  sequentialisedUnbalancedStm bnd >>= \case
    Nothing -> do
      bnd' <- transformStm path bnd
      inScopeOf bnd' $
        (bnd'<>) <$> transformStms path bnds
    Just bnds' ->
      transformStms path $ stmsToList bnds' <> bnds

-- | Does the lambda look /unbalanced/?  A body is deemed unbalanced if
-- any of its statements is, where the set of "bound" names is the
-- lambda parameters plus everything bound in the body being examined
-- (accumulated through nested branches).
unbalancedLambda :: Lambda -> Bool
unbalancedLambda lam =
  unbalancedBody
  (namesFromList $ map paramName $ lambdaParams lam) $
  lambdaBody lam

  where -- Is this operand a variable from the bound set?  Constants
        -- never are.
        subExpBound (Var i) bound = i `nameIn` bound
        subExpBound (Constant _) _ = False

        -- A body is unbalanced if any statement in it is, judged with
        -- the names bound by the body itself added to the bound set.
        unbalancedBody bound body =
          any (unbalancedStm (bound <> boundInBody body) . stmExp) $
          bodyStms body

        -- XXX - our notion of balancing is probably still too naive.

        -- A Stream or Screma is unbalanced when its width is a bound
        -- variable.
        unbalancedStm bound (Op (Stream w _ _ _)) =
          w `subExpBound` bound
        unbalancedStm bound (Op (Screma w _ _)) =
          w `subExpBound` bound
        -- Any other Op, and any sequential loop, is taken to be
        -- balanced.
        unbalancedStm _ Op{} =
          False
        unbalancedStm _ DoLoop{} = False

        -- A branch is unbalanced only if its condition is a bound
        -- variable and at least one of the branches is unbalanced.
        unbalancedStm bound (If cond tbranch fbranch _) =
          cond `subExpBound` bound &&
          (unbalancedBody bound tbranch || unbalancedBody bound fbranch)

        unbalancedStm _ (BasicOp _) =
          False
        -- Calls are unbalanced unless the callee is a built-in
        -- function.
        unbalancedStm _ (Apply fname _ _ _) =
          not $ isBuiltInFunction fname

-- | If the statement is a @Screma@ that 'isRedomapSOAC' recognises and
-- whose map lambda is both unbalanced (per 'unbalancedLambda') and
-- contains parallelism, return the statements produced by running
-- 'FOT.transformSOAC' on it.  For every other statement, return
-- 'Nothing', meaning "leave it alone".
--
-- The scope in which the replacement is constructed is the current
-- @Kernels@ scope converted with 'scopeForSOACs'.
sequentialisedUnbalancedStm :: Stm -> DistribM (Maybe (Stms SOACS))
sequentialisedUnbalancedStm (Let pat _ (Op soac@(Screma _ form _)))
  | Just (_, lam2) <- isRedomapSOAC form,
    unbalancedLambda lam2, lambdaContainsParallelism lam2 = do
      types <- asksScope scopeForSOACs
      Just . snd <$> runBinderT (FOT.transformSOAC pat soac) types
sequentialisedUnbalancedStm _ =
  return Nothing

-- | Construct a comparison of a runtime quantity against a tunable
-- size.
--
-- The quantity compared is the product of the given operands (folded
-- with 32-bit integer multiplication, starting from 1).  The name of
-- the size is the description followed by an underscore and the
-- current value of the threshold counter, which is then incremented.
--
-- Returns the result of the comparison together with the size name,
-- and the statements that compute the comparison.
cmpSizeLe :: String -> Out.SizeClass -> [SubExp]
          -> DistribM ((SubExp, Name), Out.Stms Out.Kernels)
cmpSizeLe desc size_class to_what = do
  -- Take the next threshold number; this is deliberately independent
  -- of the ordinary name source (see 'State').
  x <- gets stateThresholdCounter
  modify $ \s -> s { stateThresholdCounter = x + 1 }
  let size_key = nameFromString $ desc ++ "_" ++ show x
  runBinder $ do
    to_what' <- letSubExp "comparatee" =<<
                foldBinOp (Mul Int32) (intConst Int32 1) to_what
    cmp_res <- letSubExp desc $ Op $ SizeOp $ CmpSizeLe size_key size_class to_what'
    return (cmp_res, size_key)

-- | Build a chain of nested branches that selects between several
-- alternative bodies computing the same pattern.
--
-- Given alternatives @[(c1, b1), (c2, b2), ...]@ and a default body
-- @d@, the result binds the pattern to
--
-- @
--   if c1 then b1 else (if c2 then b2 else ... d)
-- @
--
-- With no alternatives, the default body is simply inlined and each of
-- its results is bound to the corresponding name of the pattern.
kernelAlternatives :: (MonadFreshNames m, HasScope Out.Kernels m) =>
                      Out.Pattern Out.Kernels
                   -> Out.Body Out.Kernels
                   -> [(SubExp, Out.Body Out.Kernels)]
                   -> m (Out.Stms Out.Kernels)
kernelAlternatives pat default_body [] = runBinder_ $ do
  ses <- bodyBind default_body
  forM_ (zip (patternNames pat) ses) $ \(name, se) ->
    letBindNames_ [name] $ BasicOp $ SubExp se
kernelAlternatives pat default_body ((cond,alt):alts) = runBinder_ $ do
  -- The remaining alternatives are bound to a copy of the pattern with
  -- fresh names (and an empty context), since they end up inside the
  -- false branch and must not clash with the names of 'pat' itself.
  alts_pat <- fmap (Pattern []) $ forM (patternElements pat) $ \pe -> do
    name <- newVName $ baseString $ patElemName pe
    return pe { patElemName = name }

  alt_stms <- kernelAlternatives alts_pat default_body alts
  let alt_body = mkBody alt_stms $ map Var $ patternValueNames alts_pat

  letBind_ pat $ If cond alt alt_body $ ifCommon $ patternTypes pat

-- | Transform a single SOACS statement into a sequence of Kernels
-- statements.  The 'KernelPath' is passed on to the helpers that
-- create size thresholds and to the recursive transformations of
-- nested bodies and statements.
transformStm :: KernelPath -> Stm -> DistribM KernelsStms

-- A threshold comparison is replaced by whatever 'cmpSizeLe' produces
-- for a 'Out.SizeThreshold' size class built from the current path.
-- The statements it emits come first; the original pattern is then
-- bound to the resulting subexpression, keeping the original 'aux'.
transformStm path (Let pat aux (Op (CmpThreshold what s))) = do
  ((r, _), stms) <- cmpSizeLe s (Out.SizeThreshold path) [what]
  runBinder_ $ do
    addStms stms
    addStm $ Let pat aux $ BasicOp $ SubExp r

-- Conditionals are kept as conditionals: both branches are
-- transformed independently under the same path, and the statement is
-- rebuilt with the original pattern, auxiliary information, condition
-- and return type annotation.
transformStm path (Let pat aux (If c tb fb rt)) = do
  tb' <- transformBody path tb
  fb' <- transformBody path fb
  return $ oneStm $ Let pat aux $ If c tb' fb' rt

-- Sequential loops are kept as loops; only the body is transformed.
-- The body is transformed with the names bound by the loop form and
-- by the merge parameters (context and value) added to the scope.
transformStm path (Let pat aux (DoLoop ctx val form body)) =
  localScope (castScope (scopeOf form) <>
              scopeOfFParams mergeparams) $
    oneStm . Let pat aux . DoLoop ctx val form' <$> transformBody path body
  where -- All loop parameters, context parameters first.
        mergeparams = map fst $ ctx ++ val
        -- The loop form is rebuilt constructor by constructor with
        -- unchanged fields, so that it is typed for the Kernels
        -- representation rather than SOACS.
        form' = case form of
                  WhileLoop cond ->
                    WhileLoop cond
                  ForLoop i it bound ps ->
                    ForLoop i it bound ps

-- A Screma that 'isMapSOAC' recognises as a plain map is packaged as
-- a 'MapLoop' (carrying the statement's certificates) and handed to
-- 'onMap'.  If the guard fails, matching falls through to the
-- following clauses.
transformStm path (Let pat (StmAux cs _) (Op (Screma w form arrs)))
  | Just lam <- isMapSOAC form =
      onMap path $ MapLoop pat cs w lam arrs

-- Scans.  Two strategies are tried, in order; if neither guard
-- succeeds, matching falls through to the following clauses.
transformStm path (Let res_pat (StmAux cs _) (Op (Screma w form arrs)))
  -- First strategy: if 'iswim' yields a rewrite for this scan, run it
  -- (with the statement's certificates attached) in a SOACS binder and
  -- transform the statements it produced.
  | Just (scan_lam, nes) <- isScanSOAC form,
    Just do_iswim <- iswim res_pat w scan_lam $ zip nes arrs = do
      types <- asksScope scopeForSOACs
      transformStms path =<<
        (stmsToList . snd <$> runBinderT (certifying cs do_iswim) types)

  -- We are only willing to generate code for scanomaps that do not
  -- involve array accumulators, and do not have parallelism in their
  -- map function.  Such cases will fall through to the
  -- screma-splitting case, and produce an ordinary map and scan.
  -- Hopefully, the scan then triggers the ISWIM case above (otherwise
  -- we will still crash in code generation).  However, if the map
  -- lambda is already identity, let's just go ahead here.
  | Just (scan_lam, nes, map_lam) <- isScanomapSOAC form,
    (all primType (lambdaReturnType scan_lam) &&
     not (lambdaContainsParallelism map_lam)) ||
    isIdentityLambda map_lam = runBinder_ $ do
      let scan_lam' = soacsLambdaToKernels scan_lam
          map_lam' = soacsLambdaToKernels map_lam
      lvl <- segThreadCapped [w] "segscan" $ NoRecommendation SegNoVirt
      -- The certificates of the original statement must be carried
      -- over to the generated statements, exactly as the redomap
      -- clause does for 'nonSegRed'; otherwise they are silently lost.
      addStms =<<
        (fmap (certify cs) <$>
         segScan lvl res_pat w scan_lam' map_lam' nes arrs [] [])

-- A Screma that is a single reduction for which 'irwim' yields a
-- rewrite.  The rewrite is run with the statement's certificates
-- attached, the statements it produces are simplified, and the
-- simplified statements are then transformed in turn.  If any guard
-- fails, matching falls through to the following clauses.
transformStm path (Let res_pat (StmAux cs _) (Op (Screma w form arrs)))
  | Just [Reduce comm red_fun nes] <- isReduceSOAC form,
    -- Treat the operator as commutative if either the lambda is
    -- recognised as such or the reduction was already marked so.
    let comm' | commutativeLambda red_fun = Commutative
              | otherwise                 = comm,
    Just do_irwim <- irwim res_pat w comm' red_fun $ zip nes arrs = do
      types <- asksScope scopeForSOACs
      -- Only the simplified statements are needed; the symbol table
      -- returned by 'simplifyStms' and the binder's own statement
      -- list are discarded.
      (_, bnds) <- fst <$>
                   runBinderT
                   (simplifyStms =<< collectStms_ (certifying cs do_irwim))
                   types
      transformStms path $ stmsToList bnds

-- Redomaps.  Two code versions can be produced:
--
--  * "outer": a single 'nonSegRed' over the whole input, with the map
--    lambda converted by 'soacsLambdaToKernels';
--
--  * "inner": the redomap is split into a map and a reduce by
--    'redomapToMapAndReduce', and both are transformed further.
--
-- The outer version is used alone unless the map lambda contains
-- parallelism *and* incremental flattening is enabled; in that case
-- both versions are generated and combined by 'kernelAlternatives',
-- guarded by the result of 'sufficientParallelism'.
transformStm path (Let pat (StmAux cs _) (Op (Screma w form arrs)))
  | Just (reds, map_lam) <- isRedomapSOAC form = do
      let paralleliseOuter = runBinder_ $ do
            red_ops <- forM reds $ \(Reduce comm red_lam nes) -> do
              (red_lam', nes', shape) <- determineReduceOp red_lam nes
              let comm' | commutativeLambda red_lam' = Commutative
                        | otherwise = comm
              return $ SegRedOp comm' red_lam' nes' shape
            let map_lam_sequential = soacsLambdaToKernels map_lam
            lvl <- segThreadCapped [w] "segred" $ NoRecommendation SegNoVirt
            -- Carry the original statement's certificates over to
            -- every generated statement.
            addStms =<<
              (fmap (certify cs) <$>
               nonSegRed lvl pat w red_ops map_lam_sequential arrs)

          -- The outer version wrapped as a (renamed) body returning
          -- the names of the original pattern.
          outerParallelBody =
            renameBody =<<
            (mkBody <$> paralleliseOuter <*> pure (map Var (patternNames pat)))

          paralleliseInner path' = do
            (mapbnd, redbnd) <-
              redomapToMapAndReduce pat (w, comm', red_lam, map_lam, nes, arrs)
            transformStms path' [certify cs mapbnd, certify cs redbnd]
            where comm' | commutativeLambda red_lam = Commutative
                        | otherwise = comm
                  -- All reduction operators collapsed into one by
                  -- 'singleReduce'.
                  (Reduce comm red_lam nes) = singleReduce reds

          -- The inner version wrapped as a (renamed) body returning
          -- the names of the original pattern.
          innerParallelBody path' =
            renameBody =<<
            (mkBody <$> paralleliseInner path' <*>
             pure (map Var (patternNames pat)))

      if lambdaContainsParallelism map_lam && incrementalFlattening
        then do
          ((outer_suff, outer_suff_key), suff_stms) <-
            sufficientParallelism "suff_outer_redomap" [w] path

          outer_stms <- outerParallelBody
          -- The inner version is generated under a path that records
          -- the threshold key with 'False'.
          inner_stms <- innerParallelBody ((outer_suff_key, False) : path)

          (suff_stms <>) <$>
            kernelAlternatives pat inner_stms [(outer_suff, outer_stms)]
        else paralleliseOuter

-- Streams can be handled in two different ways - either we
-- sequentialise the body or we keep it parallel and distribute.
-- A parallel stream whose list of neutral elements is empty.
transformStm path (Let pat (StmAux cs _) (Op (Stream w (Parallel _ _ _ []) map_fun arrs))) = do
  -- No reduction part.  Remove the stream and leave the body
  -- parallel.  It will be distributed.
  types <- asksScope scopeForSOACs
  -- 'sequentialStreamWholeArray' is run in a SOACS binder with the
  -- statement's certificates attached; only the statements it emits
  -- are kept, and these are then transformed in turn.
  transformStms path =<<
    (stmsToList . snd <$>
     runBinderT
     (certifying cs $ sequentialStreamWholeArray pat w [] map_fun arrs)
     types)

transformStm KernelPath
path (Let Pattern SOACS
pat aux :: StmAux (ExpAttr SOACS)
aux@(StmAux Certificates
cs ExpAttr SOACS
_) (Op (Stream w (Parallel o comm red_fun nes) fold_fun arrs)))
  | Bool
incrementalFlattening = do
      ((SubExp
outer_suff, Name
outer_suff_key), KernelsStms
suff_stms) <-
        String
-> Result -> KernelPath -> DistribM ((SubExp, Name), KernelsStms)
sufficientParallelism String
"suff_outer_stream" [SubExp
w] KernelPath
path

      Body Kernels
outer_stms <- KernelPath -> DistribM (Body Kernels)
outerParallelBody ((Name
outer_suff_key, Bool
True) (Name, Bool) -> KernelPath -> KernelPath
forall a. a -> [a] -> [a]
: KernelPath
path)
      Body Kernels
inner_stms <- KernelPath -> DistribM (Body Kernels)
innerParallelBody ((Name
outer_suff_key, Bool
False) (Name, Bool) -> KernelPath -> KernelPath
forall a. a -> [a] -> [a]
: KernelPath
path)

      (KernelsStms
suff_stmsKernelsStms -> KernelsStms -> KernelsStms
forall a. Semigroup a => a -> a -> a
<>) (KernelsStms -> KernelsStms)
-> DistribM KernelsStms -> DistribM KernelsStms
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Pattern Kernels
-> Body Kernels -> [(SubExp, Body Kernels)] -> DistribM KernelsStms
forall (m :: * -> *).
(MonadFreshNames m, HasScope Kernels m) =>
Pattern Kernels
-> Body Kernels -> [(SubExp, Body Kernels)] -> m KernelsStms
kernelAlternatives Pattern SOACS
Pattern Kernels
pat Body Kernels
inner_stms [(SubExp
outer_suff, Body Kernels
outer_stms)]

  | Bool
otherwise = KernelPath -> DistribM KernelsStms
paralleliseOuter KernelPath
path

  where
    paralleliseOuter :: KernelPath -> DistribM KernelsStms
paralleliseOuter KernelPath
path'
      | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (Type -> Bool) -> [Type] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([Type] -> Bool) -> [Type] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda
red_fun = do
          -- Split into a chunked map and a reduction, with the latter
          -- further transformed.
          let fold_fun' :: Lambda Kernels
fold_fun' = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
fold_fun

          let ([PatElemT Type]
red_pat_elems, [PatElemT Type]
concat_pat_elems) =
                Int -> [PatElemT Type] -> ([PatElemT Type], [PatElemT Type])
forall a. Int -> [a] -> ([a], [a])
splitAt (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
nes) ([PatElemT Type] -> ([PatElemT Type], [PatElemT Type]))
-> [PatElemT Type] -> ([PatElemT Type], [PatElemT Type])
forall a b. (a -> b) -> a -> b
$ PatternT Type -> [PatElemT Type]
forall attr. PatternT attr -> [PatElemT attr]
patternValueElements PatternT Type
Pattern SOACS
pat
              red_pat :: PatternT Type
red_pat = [PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [] [PatElemT Type]
red_pat_elems

          ((SubExp
num_threads, [VName]
red_results), KernelsStms
stms) <-
            [String]
-> [PatElem Kernels]
-> SubExp
-> Commutativity
-> Lambda Kernels
-> Result
-> [VName]
-> DistribM ((SubExp, [VName]), KernelsStms)
forall (m :: * -> *).
(MonadFreshNames m, HasScope Kernels m) =>
[String]
-> [PatElem Kernels]
-> SubExp
-> Commutativity
-> Lambda Kernels
-> Result
-> [VName]
-> m ((SubExp, [VName]), KernelsStms)
streamMap ((PatElemT Type -> String) -> [PatElemT Type] -> [String]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> String
baseString (VName -> String)
-> (PatElemT Type -> VName) -> PatElemT Type -> String
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElemT Type -> VName
forall attr. PatElemT attr -> VName
patElemName) [PatElemT Type]
red_pat_elems) [PatElemT Type]
[PatElem Kernels]
concat_pat_elems SubExp
w
            Commutativity
Noncommutative Lambda Kernels
fold_fun' Result
nes [VName]
arrs

          ScremaForm SOACS
reduce_soac <- [Reduce SOACS] -> DistribM (ScremaForm SOACS)
forall lore (m :: * -> *).
(Bindable lore, MonadFreshNames m) =>
[Reduce lore] -> m (ScremaForm lore)
reduceSOAC [Commutativity -> Lambda -> Result -> Reduce SOACS
forall lore. Commutativity -> Lambda lore -> Result -> Reduce lore
Reduce Commutativity
comm' Lambda
red_fun Result
nes]

          (KernelsStms
stmsKernelsStms -> KernelsStms -> KernelsStms
forall a. Semigroup a => a -> a -> a
<>) (KernelsStms -> KernelsStms)
-> DistribM KernelsStms -> DistribM KernelsStms
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$>
            KernelsStms -> DistribM KernelsStms -> DistribM KernelsStms
forall lore a (m :: * -> *) b.
(Scoped lore a, LocalScope lore m) =>
a -> m b -> m b
inScopeOf KernelsStms
stms
            (KernelPath -> Stm -> DistribM KernelsStms
transformStm KernelPath
path' (Stm -> DistribM KernelsStms) -> Stm -> DistribM KernelsStms
forall a b. (a -> b) -> a -> b
$ Pattern SOACS -> StmAux (ExpAttr SOACS) -> ExpT SOACS -> Stm
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let PatternT Type
Pattern SOACS
red_pat StmAux (ExpAttr SOACS)
aux (ExpT SOACS -> Stm) -> ExpT SOACS -> Stm
forall a b. (a -> b) -> a -> b
$
             Op SOACS -> ExpT SOACS
forall lore. Op lore -> ExpT lore
Op (SubExp -> ScremaForm SOACS -> [VName] -> SOAC SOACS
forall lore. SubExp -> ScremaForm lore -> [VName] -> SOAC lore
Screma SubExp
num_threads ScremaForm SOACS
reduce_soac [VName]
red_results))

      | Bool
otherwise = do
          let red_fun_sequential :: Lambda Kernels
red_fun_sequential = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
red_fun
              fold_fun_sequential :: Lambda Kernels
fold_fun_sequential = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
fold_fun
          (Stm Kernels -> Stm Kernels) -> KernelsStms -> KernelsStms
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> Stm Kernels -> Stm Kernels
forall lore. Certificates -> Stm lore -> Stm lore
certify Certificates
cs) (KernelsStms -> KernelsStms)
-> DistribM KernelsStms -> DistribM KernelsStms
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$>
            Pattern Kernels
-> SubExp
-> Commutativity
-> Lambda Kernels
-> Lambda Kernels
-> Result
-> [VName]
-> DistribM KernelsStms
forall (m :: * -> *).
(MonadFreshNames m, HasScope Kernels m) =>
Pattern Kernels
-> SubExp
-> Commutativity
-> Lambda Kernels
-> Lambda Kernels
-> Result
-> [VName]
-> m KernelsStms
streamRed Pattern SOACS
Pattern Kernels
pat SubExp
w Commutativity
comm' Lambda Kernels
red_fun_sequential Lambda Kernels
fold_fun_sequential Result
nes [VName]
arrs

    -- Build a body from the statements produced by 'paralleliseOuter'
    -- whose result is the names bound by @pat@, and rename it so that
    -- it can be used alongside other versions of the same code.
    outerParallelBody path' = do
      outer_stms <- paralleliseOuter path'
      renameBody $ mkBody outer_stms $ map Var $ patternNames pat

    -- Turn the stream into statements over the whole input arrays
    -- (via 'sequentialStreamWholeArray'), attach the certificates
    -- @cs@ to each of them, and run the result through
    -- 'transformStms' again.
    paralleliseInner path' = do
      soacs_scope <- asksScope scopeForSOACs
      (_, whole_array_stms) <-
        runBinderT (sequentialStreamWholeArray pat w nes fold_fun arrs) soacs_scope
      transformStms path' $ map (certify cs) $ stmsToList whole_array_stms

    -- Like 'outerParallelBody', but wrapping the statements produced
    -- by 'paralleliseInner'.
    innerParallelBody path' = do
      inner_stms <- paralleliseInner path'
      renameBody $ mkBody inner_stms $ map Var $ patternNames pat

    -- The reduction is treated as commutative when the operator is
    -- recognised as such by 'commutativeLambda' and the stream is not
    -- required to be 'InOrder'; otherwise the declared commutativity
    -- is kept.
    comm' =
      if commutativeLambda red_fun && o /= InOrder
      then Commutative
      else comm

-- A screma that none of the preceding clauses handled directly: it is
-- too complicated for us to immediately do anything, so split it up
-- with 'dissectScrema' and try again on the pieces.  The certificates
-- of the original statement are attached to every piece.
transformStm path (Let pat (StmAux cs _) (Op (Screma w form arrs))) = do
  soacs_scope <- asksScope scopeForSOACs
  (_, dissected) <- runBinderT (dissectScrema pat w form arrs) soacs_scope
  transformStms path $ map (certify cs) $ stmsToList dissected

-- A sequential stream: remove the stream and leave the body parallel.
-- It will be distributed when the resulting statements are passed
-- back through 'transformStms'.
transformStm path (Let pat (StmAux cs _) (Op (Stream w (Sequential nes) fold_fun arrs))) = do
  types <- asksScope scopeForSOACs
  -- The statements that replace the stream must carry the
  -- certificates of the original statement; otherwise the dependency
  -- on those certificates is silently lost.
  transformStms path . map (certify cs) . stmsToList . snd =<<
    runBinderT (sequentialStreamWholeArray pat w nes fold_fun arrs) types

-- A scatter becomes a single kernel built with 'mapKernel' over the
-- index @write_i@ of width @w@.  The lambda body's results are split
-- into indices (the first @sum as_ns@ results) and values (the rest),
-- paired up, and chunked per destination array, with one
-- 'WriteReturns' per destination.
transformStm _ (Let pat (StmAux cs _) (Op (Scatter w lam ivs as))) = runBinder_ $ do
  write_i <- newVName "write_i"
  let kernel_lam = soacsLambdaToKernels lam
      kernel_lam_body = lambdaBody kernel_lam
      (dest_ws, dest_ns, dest_arrs) = unzip3 as
      (index_res, value_res) =
        splitAt (sum dest_ns) $ bodyResult kernel_lam_body
      -- One list of (index, value) pairs per destination array.
      writes_per_dest = chunks dest_ns $ zip index_res value_res
      krets =
        [ WriteReturns [dest_w] dest [ ([i], v) | (i, v) <- writes ]
        | (dest_w, dest, writes) <- zip3 dest_ws dest_arrs writes_per_dest
        ]
      body = KernelBody () (bodyStms kernel_lam_body) krets
      -- Each lambda parameter reads its input array at @write_i@.
      inputs =
        [ KernelInput (paramName p) (paramType p) arr [Var write_i]
        | (p, arr) <- zip (lambdaParams kernel_lam) ivs
        ]
  (kernel, kernel_stms) <-
    mapKernel segThreadCapped [(write_i, w)] inputs
    (map rowType $ patternTypes pat) body
  certifying cs $ do
    addStms kernel_stms
    letBind_ pat $ Op $ SegOp kernel

-- A histogram is handed to 'histKernel' with no enclosing nesting
-- (both the index-space and kernel-input lists are empty).
transformStm _ (Let orig_pat (StmAux cs _) (Op (Hist w ops bucket_fun imgs))) =
  runBinder_ $ do
    -- It is important not to launch unnecessarily many threads for
    -- histograms, because it may mean we unnecessarily need to reduce
    -- subhistograms as well.
    lvl <- segThreadCapped [w] "seghist" $ NoRecommendation SegNoVirt
    hist_stms <- histKernel lvl orig_pat [] [] cs w ops kernel_bucket_fun imgs
    addStms hist_stms
  where kernel_bucket_fun = soacsLambdaToKernels bucket_fun

-- Any statement not matched by an earlier clause is passed to
-- 'FOT.transformStmRecursively', and the statements it adds are
-- collected.
transformStm _ stm =
  runBinder_ (FOT.transformStmRecursively stm)

-- | Compare the widths @ws@ against a threshold size whose class is
-- 'Out.SizeThreshold' for the given kernel path, using 'cmpSizeLe'.
-- The result of 'cmpSizeLe' is returned unchanged; presumably it is
-- the boolean outcome, the name of the threshold, and the statements
-- computing them (confirm against 'cmpSizeLe').
sufficientParallelism :: String -> [SubExp] -> KernelPath
                      -> DistribM ((SubExp, Name), Out.Stms Out.Kernels)
sufficientParallelism desc ws path = cmpSizeLe desc threshold_class ws
  where threshold_class = Out.SizeThreshold path

-- | Returns the sizes of nested parallelism: the widths of the
-- scatters, scremas and histograms that occur as statements of the
-- body, including those found inside loop bodies and inside
-- sequential streams.
nestedParallelism :: Body -> [SubExp]
nestedParallelism body = concatMap (widthsOf . stmExp) (bodyStms body)
  where widthsOf e = case e of
          Op (Scatter w _ _ _) -> [w]
          Op (Screma w _ _) -> [w]
          Op (Hist w _ _ _) -> [w]
          Op (Stream w Sequential{} lam _)
            | chunk_size_param : _ <- lambdaParams lam ->
                -- Inside the stream, a width equal to the chunk size
                -- is reported as the width of the whole stream.
                [ if se == Var (paramName chunk_size_param) then w else se
                | se <- nestedParallelism (lambdaBody lam)
                ]
          DoLoop _ _ _ loop_body -> nestedParallelism loop_body
          _ -> []

-- | Intra-group parallelism is worthwhile if the lambda contains
-- non-map nested parallelism, or any nested parallelism inside a
-- loop.  Concretely: the body must have some nested parallelism at
-- all, and at least one of its statements must be something other
-- than a map (or scatter) or a sequential construct.
worthIntraGroup :: Lambda -> Bool
worthIntraGroup lam = has_nested_par && has_non_map
  where lam_body = lambdaBody lam
        has_nested_par = not $ null $ nestedParallelism lam_body
        has_non_map = any (not . mapOrSequential . stmExp) $ bodyStms lam_body

        mapOrSequential e = case e of
          -- A map counts only if its own lambda is not itself worth
          -- intra-group parallelisation.
          Op (Screma _ form@(ScremaForm _ _ map_lam) _)
            | isJust (isMapSOAC form) -> not (worthIntraGroup map_lam)
          Op Scatter{} -> True -- Basically a map.
          DoLoop _ _ _ loop_body -> null (nestedParallelism loop_body)
          Op _ -> False
          _ -> True

-- | A lambda is worth sequentialising if it contains nested
-- parallelism of an interesting kind: any SOAC other than a scatter
-- or a map whose own lambda is uninteresting, possibly inside a loop.
worthSequentialising :: Lambda -> Bool
worthSequentialising = bodyIsInteresting . lambdaBody
  where bodyIsInteresting body =
          any (expIsInteresting . stmExp) (bodyStms body)

        expIsInteresting e = case e of
          -- A map is only as interesting as its lambda.
          Op (Screma _ form@(ScremaForm _ _ map_lam) _)
            | isJust (isMapSOAC form) -> worthSequentialising map_lam
          Op Scatter{} -> False -- Basically a map.
          DoLoop _ _ _ loop_body -> bodyIsInteresting loop_body
          Op _ -> True
          _ -> False


-- | Transform statements with 'transformStms' in the underlying
-- 'DistribM' monad, running it under the scope currently visible in
-- the 'DistNestT' layer.
onTopLevelStms :: KernelPath -> Stms SOACS -> DistNestT DistribM KernelsStms
onTopLevelStms path stms =
  askScope >>= \scope ->
    lift (localScope scope (transformStms path (stmsToList stms)))

-- | Handle a map.  Without incremental flattening the body of the
-- map is simply distributed ("inner parallelism").  With incremental
-- flattening, 'onMap'' is given two ways of generating code - one
-- that keeps the whole lambda body in a single kernel ("outer
-- parallelism") and one that distributes the body - and decides
-- between them.
onMap :: KernelPath -> MapLoop -> DistribM KernelsStms
onMap path (MapLoop pat cs w lam arrs) = do
  types <- askScope
  let lam_body = lambdaBody lam
      loopnest = MapNesting pat cs w (zip (lambdaParams lam) arrs)

      -- Initial accumulator: the only target is the pattern of the
      -- map itself, and no statements have been collected yet.
      initial_acc =
        DistAcc { distTargets = singleTarget (pat, bodyResult lam_body)
                , distStms = mempty
                }

      envFor path' =
        DistEnv { distNest = singleNesting (Nesting mempty loopnest)
                , distScope = scopeOfPattern pat <>
                              scopeForKernels (scopeOf lam) <>
                              types
                , distOnInnerMap = onInnerMap path'
                , distOnTopLevelStms = onTopLevelStms path'
                , distSegLevel = segThreadCapped
                }

      -- Distribute the statements of the map body.
      exploitInnerParallelism path' =
        runDistNestT (envFor path') $
        distributeMapBodyStms initial_acc (bodyStms lam_body)

      -- Put the entire (converted) map body into one kernel.
      exploitOuterParallelism path' =
        runDistNestT (envFor path') $ distribute $
        addStmsToKernel
        (bodyStms $ lambdaBody $ soacsLambdaToKernels lam)
        initial_acc

  if incrementalFlattening
    then onMap' (newKernel loopnest) path
         exploitOuterParallelism exploitInnerParallelism pat lam
    else exploitInnerParallelism path

-- | Produce the multi-versioned statements for a single map nest.
--
-- Arguments:
--
--  * @loopnest@: the kernel nest being versioned; its widths (via
--    'kernelNestWidths') feed the outer sufficiency test.
--
--  * @path@: the versioning decisions taken so far.  Each version
--    built here extends it with @(key, Bool)@ entries.
--
--  * @mk_seq_stms@: builds the statements of the version guarded by
--    the @"suff_outer_par"@ condition; it is called with
--    @(outer_suff_key, True)@ prepended to @path@.
--
--  * @mk_par_stms@: builds the statements of the remaining (default)
--    version; it is called with @(outer_suff_key, False)@ (and, when
--    an intra-group version exists, also @(intra_suff_key, False)@)
--    prepended to @path@.
--
--  * @pat@: the pattern handed to 'kernelAlternatives'; its names are
--    also used as the result of every version body.
--
--  * @lam@: the lambda inspected by 'worthIntraGroup' and
--    'worthSequentialising', and passed to 'intraGroupParallelise'.
--
-- The returned statements are the sufficiency-test statements
-- followed by whatever 'kernelAlternatives' produces.
onMap' :: KernelNest -> KernelPath
       -> (KernelPath -> DistribM (Out.Stms Out.Kernels))
       -> (KernelPath -> DistribM (Out.Stms Out.Kernels))
       -> Pattern
       -> Lambda
       -> DistribM (Out.Stms Out.Kernels)
onMap' loopnest path mk_seq_stms mk_par_stms pat lam = do
  let nest_ws = kernelNestWidths loopnest
      -- Every version body returns the pattern names as variables.
      res = map Var $ patternNames pat

  types <- askScope
  ((outer_suff, outer_suff_key), outer_suff_stms) <-
    sufficientParallelism "suff_outer_par" nest_ws path

  -- Only attempt an intra-group version when the lambda is deemed
  -- worth it; 'intraGroupParallelise' runs with the current scope.
  intra <- if worthIntraGroup lam then
             flip runReaderT types $ intraGroupParallelise loopnest lam
           else return Nothing

  -- Note that the body is built unconditionally, even though it only
  -- ends up in 'seq_alts' when 'worthSequentialising' holds.
  seq_body <- renameBody =<< mkBody <$>
              mk_seq_stms ((outer_suff_key, True) : path) <*> pure res
  let seq_alts = [(outer_suff, seq_body) | worthSequentialising lam]

  case intra of
    Nothing -> do
      par_body <- renameBody =<< mkBody <$>
                  mk_par_stms ((outer_suff_key, False) : path) <*> pure res

      (outer_suff_stms <>) <$> kernelAlternatives pat par_body seq_alts

    Just ((_intra_min_par, intra_avail_par), group_size, intra_log, intra_prelude, intra_stms) -> do
      addLog intra_log
      -- We must check that all intra-group parallelism fits in a group.
      ((intra_ok, intra_suff_key), intra_suff_stms) <- do

        ((intra_suff, suff_key), check_suff_stms) <-
          sufficientParallelism "suff_intra_par" [intra_avail_par] $
          (outer_suff_key, False) : path

        runBinder $ do

          -- The prelude is added first, ahead of the comparison that
          -- reads 'group_size'.
          addStms intra_prelude

          max_group_size <-
            letSubExp "max_group_size" $ Op $ SizeOp $ Out.GetSizeMax Out.SizeGroup
          fits <- letSubExp "fits" $ BasicOp $
                  CmpOp (CmpSle Int32) group_size max_group_size

          addStms check_suff_stms

          -- The intra-group version is only taken when it both fits
          -- in a group and has sufficient parallelism.
          intra_ok <- letSubExp "intra_suff_and_fits" $ BasicOp $ BinOp LogAnd fits intra_suff
          return (intra_ok, suff_key)

      group_par_body <- renameBody $ mkBody intra_stms res

      par_body <- renameBody =<< mkBody <$>
                  mk_par_stms ([(outer_suff_key, False),
                                (intra_suff_key, False)]
                                ++ path) <*> pure res

      ((outer_suff_stms <> intra_suff_stms) <>) <$>
        kernelAlternatives pat par_body (seq_alts ++ [(intra_ok, group_par_body)])

-- | Handle a map found inside a nest being distributed.
--
-- Three cases, tried in order:
--
--  1. The lambda is unbalanced ('unbalancedLambda') and contains
--     parallelism ('lambdaContainsParallelism'): the map statement is
--     added to the current kernel as-is via 'addStmToKernel'.
--
--  2. 'incrementalFlattening' is off: the map is handed to
--     'distributeMap'.
--
--  3. Otherwise 'distributeSingleStm' is tried on the map statement.
--     If that succeeds and 'permutationAndMissing' yields a
--     permutation, the resulting kernels are recorded and the map is
--     multi-versioned by @multiVersion@; in every other outcome this
--     falls back to 'distributeMap'.
--
-- @path@ is the versioning path so far; it is passed on to 'onMap''.
onInnerMap :: KernelPath -> MapLoop -> DistAcc -> DistNestT DistribM DistAcc
onInnerMap path maploop@(MapLoop pat cs w lam arrs) acc
  | unbalancedLambda lam, lambdaContainsParallelism lam =
      addStmToKernel (mapLoopStm maploop) acc
  | not incrementalFlattening =
      distributeMap maploop acc
  | otherwise =
      distributeSingleStm acc (mapLoopStm maploop) >>= \case
        Just (post_kernels, res, nest, acc')
          | Just (perm, _pat_unused) <- permutationAndMissing pat res -> do
              addKernels post_kernels
              multiVersion perm nest acc'
        _ -> distributeMap maploop acc

  where
    -- Replace the targets of an accumulator with a single empty target.
    discardTargets acc' =
      -- FIXME: work around bogus targets.
      acc' { distTargets = singleTarget (mempty, mempty) }

    -- Build both a sequentialised and an inner-parallel version of the
    -- map via 'onMap'', register the result with 'addKernel', and
    -- return @acc'@ unchanged.
    multiVersion perm nest acc' = do
      -- The kernel can be distributed by itself, so now we can
      -- decide whether to just sequentialise, or exploit inner
      -- parallelism.
      dist_env <- ask
      let extra_scope = targetsScope $ distTargets acc'
      scope <- (extra_scope <>) <$> askScope

      stms <- lift $ localScope scope $ do
        let maploop' = MapLoop pat cs w lam arrs

            -- Re-run distribution on the map inside @nest@, with the
            -- environment callbacks rebound to the extended path.
            -- NOTE(review): this starts from the outer @acc@ (with its
            -- statements cleared), not from @acc'@ - confirm that this
            -- is intended.
            exploitInnerParallelism path' = do
              let dist_env' =
                    dist_env { distOnTopLevelStms = onTopLevelStms path'
                             , distOnInnerMap = onInnerMap path'
                             }
              runDistNestT dist_env' $
                inNesting nest $ localScope extra_scope $
                discardTargets <$> distributeMap maploop' acc { distStms = mempty }

        -- Normally the permutation is for the output pattern, but
        -- we can't really change that, so we change the result
        -- order instead.
        let lam_res' = rearrangeShape perm $ bodyResult $ lambdaBody lam
            lam' = lam { lambdaBody = (lambdaBody lam) { bodyResult = lam_res' } }
            map_nesting = MapNesting pat cs w $ zip (lambdaParams lam) arrs
            nest' = pushInnerKernelNesting (pat, lam_res') map_nesting nest

        -- XXX: we do not construct a new KernelPath when
        -- sequentialising.  This is only OK as long as further
        -- versioning does not take place down that branch (it currently
        -- does not).
        (sequentialised_kernel, nestw_bnds) <- localScope extra_scope $ do
          let sequentialised_lam = soacsLambdaToKernels lam'
          constructKernel segThreadCapped nest' $ lambdaBody sequentialised_lam

        let outer_pat = loopNestingPattern $ fst nest
        -- The sequential alternative ignores the path it is given and
        -- always yields the kernel constructed above.
        (nestw_bnds <>) <$>
          onMap' nest' path
          (const $ return $ oneStm sequentialised_kernel)
          exploitInnerParallelism
          outer_pat lam'

      addKernel stms
      return acc'