{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

-- | Kernel extraction.
--
-- In the following, I will use the term "width" to denote the amount
-- of immediate parallelism in a map - that is, the outer size of the
-- array(s) being used as input.
--
-- = Basic Idea
--
-- If we have:
--
-- @
--   map
--     map(f)
--     bnds_a...
--     map(g)
-- @
--
-- Then we want to distribute to:
--
-- @
--   map
--     map(f)
--   map
--     bnds_a
--   map
--     map(g)
-- @
--
-- But for now only if
--
--  (0) it can be done without creating irregular arrays.
--      Specifically, the size of the arrays created by @map(f)@, by
--      @map(g)@ and whatever is created by @bnds_a@ that is also used
--      in @map(g)@, must be invariant to the outermost loop.
--
--  (1) the maps are _balanced_.  That is, the functions @f@ and @g@
--      must do the same amount of work for every iteration.
--
-- The advantage is that the map-nests containing @map(f)@ and
-- @map(g)@ can now be trivially flattened at no cost, thus exposing
-- more parallelism.  Note that the @bnds_a@ map constitutes array
-- expansion, which requires additional storage.
--
-- = Distributing Sequential Loops
--
-- As a starting point, sequential loops are treated like scalar
-- expressions.  That is, not distributed.  However, sometimes it can
-- be worthwhile to distribute if they contain a map:
--
-- @
--   map
--     loop
--       map
--     map
-- @
--
-- If we distribute the loop and interchange the outer map into the
-- loop, we get this:
--
-- @
--   loop
--     map
--       map
--   map
--     map
-- @
--
-- Now more parallelism may be available.
--
-- = Unbalanced Maps
--
-- Unbalanced maps will as a rule be sequentialised, but sometimes,
-- there is another way.  Assume we find this:
--
-- @
--   map
--     map(f)
--       map(g)
--     map
-- @
--
-- Presume that @map(f)@ is unbalanced.  By the simple rule above, we
-- would then fully sequentialise it, resulting in this:
--
-- @
--   map
--     loop
--   map
--     map
-- @
--
-- == Balancing by Loop Interchange
--
-- The above is not ideal, as we cannot flatten the @map-loop@ nest,
-- and we are thus limited in the amount of parallelism available.
--
-- But assume now that the width of @map(g)@ is invariant to the outer
-- loop.  Then if possible, we can interchange @map(f)@ and @map(g)@,
-- sequentialise @map(f)@ and distribute, interchanging the outer
-- parallel loop into the sequential loop:
--
-- @
--   loop(f)
--     map
--       map(g)
--   map
--     map
-- @
--
-- After flattening the two nests we can obtain more parallelism.
--
-- When distributing a map, we also need to distribute everything that
-- the map depends on - possibly as its own map.  When distributing a
-- set of scalar bindings, we will need to know which of the binding
-- results are used afterwards.  Hence, we will need to compute usage
-- information.
--
-- = Redomap
--
-- Redomap can be handled much like map.  Distributed loops are
-- distributed as maps, with the parameters corresponding to the
-- neutral elements added to their bodies.  The remaining loop will
-- remain a redomap.  Example:
--
-- @
-- redomap(op,
--         fn (v) =>
--           map(f)
--           map(g),
--         e,a)
-- @
--
-- distributes to
--
-- @
-- let b = map(fn v =>
--               let acc = e
--               map(f),
--               a)
-- redomap(op,
--         fn (v,dist) =>
--           map(g),
--         e,a,b)
-- @
--
-- Note that there may be further kernel extraction opportunities
-- inside the @map(f)@.  The downside of this approach is that the
-- intermediate array (@b@ above) must be written to main memory.  An
-- often better approach is to just turn the entire @redomap@ into a
-- single kernel.
module Futhark.Pass.ExtractKernels (extractKernels) where

import Control.Monad.Identity
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Data.Bifunctor (first)
import Data.Maybe
import qualified Futhark.IR.GPU as Out
import Futhark.IR.GPU.Kernel
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyStms)
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.Intragroup
import Futhark.Pass.ExtractKernels.StreamKernel
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT
import Futhark.Transform.Rename
import Futhark.Util.Log
import Prelude hiding (log)

-- | Transform a program using SOACs to a program using explicit
-- kernels, using the kernel extraction transformation.
--
-- The actual work is done by 'transformProg'; this value merely
-- packages it up as a 'Pass' with a name and a description.
extractKernels :: Pass SOACS Out.GPU
extractKernels =
  Pass
    { passName = "extract kernels",
      passDescription = "Perform kernel extraction",
      passFunction = transformProg
    }

-- | Transform an entire program.  The top-level constants are
-- transformed first (with an empty kernel path); every function
-- definition is then transformed with the scope of the /transformed/
-- constants available.
transformProg :: Prog SOACS -> PassM (Prog Out.GPU)
transformProg (Prog consts funs) = do
  consts' <- runDistribM $ transformStms mempty $ stmsToList consts
  funs' <- mapM (transformFunDef $ scopeOf consts') funs
  return $ Prog consts' funs'

-- In order to generate more stable threshold names, we keep track of
-- the numbers used for thresholds separately from the ordinary name
-- source.
data State = State
  { -- | The ordinary source of fresh variable names.
    stateNameSource :: VNameSource,
    -- | The number that the next generated threshold will be
    -- suffixed with (see 'cmpSizeLe').
    stateThresholdCounter :: Int
  }

-- | The monad in which kernel extraction is performed.  It is an
-- 'RWS' monad where the reader environment is the scope (in the GPU
-- representation), the writer output is the compiler 'Log', and the
-- state is a 'State' value.  All instances are obtained by newtype
-- deriving from the underlying 'RWS'.
newtype DistribM a = DistribM (RWS (Scope Out.GPU) Log State a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      HasScope Out.GPU,
      LocalScope Out.GPU,
      MonadState State,
      MonadLogger
    )

-- | Fresh names are drawn from the 'stateNameSource' field of the
-- state; the threshold counter is left untouched.
instance MonadFreshNames DistribM where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s {stateNameSource = src}

-- | Run a 'DistribM' action inside some other monad that supports
-- fresh names and logging.
--
-- The action starts with an empty scope and a threshold counter of
-- zero.  The name source is taken from the enclosing monad and put
-- back afterwards, and any log messages produced by the action are
-- forwarded to the enclosing monad via 'addLog'.
runDistribM ::
  (MonadLogger m, MonadFreshNames m) =>
  DistribM a ->
  m a
runDistribM (DistribM m) = do
  (x, msgs) <- modifyNameSource $ \src ->
    let (x, s, msgs) = runRWS m mempty (State src 0)
     in ((x, msgs), stateNameSource s)
  addLog msgs
  return x

-- | Transform a single function definition.  The first argument is
-- the scope that the function body may refer to in addition to its
-- own parameters (the caller in this module passes the scope of the
-- transformed constants).
--
-- Only the body is rewritten; the entry point information,
-- attributes, name, return type and parameters are carried over
-- unchanged.  Each function is processed in its own 'runDistribM'
-- invocation, and the body is transformed with an empty kernel path.
transformFunDef ::
  (MonadFreshNames m, MonadLogger m) =>
  Scope Out.GPU ->
  FunDef SOACS ->
  m (Out.FunDef Out.GPU)
transformFunDef scope (FunDef entry attrs name rettype params body) = runDistribM $ do
  body' <-
    localScope (scope <> scopeOfFParams params) $
      transformBody mempty body
  return $ FunDef entry attrs name rettype params body'

-- | Shorthand for a sequence of statements in the GPU representation
-- (the output representation of this pass).
type GPUStms = Stms Out.GPU

-- | Transform a body by transforming its statements with
-- 'transformStms'.  The body result is reused as-is.
transformBody :: KernelPath -> Body -> DistribM (Out.Body Out.GPU)
transformBody path body = do
  bnds <- transformStms path $ stmsToList $ bodyStms body
  return $ mkBody bnds $ bodyResult body

-- | Transform a list of statements, in order.
--
-- Each statement is first offered to 'sequentialisedUnbalancedStm'.
-- If that produces replacement statements, they are put in front of
-- the remaining statements and processing continues with them (so
-- the replacement statements are themselves subject to
-- transformation).  Otherwise the statement is handled by
-- 'transformStm', and the remaining statements are transformed with
-- the names bound by the result in scope.
transformStms :: KernelPath -> [Stm] -> DistribM GPUStms
transformStms _ [] =
  return mempty
transformStms path (bnd : bnds) =
  sequentialisedUnbalancedStm bnd >>= \case
    Nothing -> do
      bnd' <- transformStm path bnd
      inScopeOf bnd' $
        (bnd' <>) <$> transformStms path bnds
    Just bnds' ->
      transformStms path $ stmsToList bnds' <> bnds

-- | Heuristically determine whether a lambda is /unbalanced/, that
-- is, whether the amount of work it performs may differ between
-- invocations.
--
-- A name counts as "bound" if it is a parameter of the lambda or is
-- bound by a body that encloses the statement being inspected.  A
-- statement is considered unbalanced when:
--
--  * it is a @Stream@ or @Screma@ whose width is a bound variable;
--
--  * it is a @WithAcc@ whose lambda body is unbalanced;
--
--  * it is an @If@ whose condition is a bound variable and at least
--    one of whose branches is unbalanced;
--
--  * it is an @Apply@ of a function that is not built-in.
--
-- Other operations, loops and basic operations are never considered
-- unbalanced.
unbalancedLambda :: Lambda -> Bool
unbalancedLambda orig_lam =
  unbalancedBody (namesFromList $ map paramName $ lambdaParams orig_lam) $
    lambdaBody orig_lam
  where
    -- Is this operand a variable from the set of bound names?
    -- Constants are never bound.
    subExpBound (Var i) bound = i `nameIn` bound
    subExpBound (Constant _) _ = False

    -- A body is unbalanced if any of its statements is.  The names
    -- bound in the body itself are added to the bound set first.
    unbalancedBody bound body =
      any (unbalancedStm (bound <> boundInBody body) . stmExp) $
        bodyStms body

    -- XXX - our notion of balancing is probably still too naive.
    unbalancedStm bound (Op (Stream w _ _ _ _)) =
      w `subExpBound` bound
    unbalancedStm bound (Op (Screma w _ _)) =
      w `subExpBound` bound
    unbalancedStm _ Op {} =
      False
    unbalancedStm _ DoLoop {} = False
    unbalancedStm bound (WithAcc _ lam) =
      unbalancedBody bound (lambdaBody lam)
    unbalancedStm bound (If cond tbranch fbranch _) =
      cond `subExpBound` bound
        && (unbalancedBody bound tbranch || unbalancedBody bound fbranch)
    unbalancedStm _ (BasicOp _) =
      False
    unbalancedStm _ (Apply fname _ _ _) =
      not $ isBuiltInFunction fname

-- | If the statement is a @Screma@ that 'isRedomapSOAC' recognises,
-- and whose map lambda is both unbalanced (per 'unbalancedLambda')
-- and contains parallelism, return the statements produced by
-- applying 'FOT.transformSOAC' (from the first-order transform
-- module) to it.  For every other statement, return 'Nothing'.
--
-- The binder is run with the current GPU scope converted by
-- 'scopeForSOACs', as the resulting statements are still in the
-- SOACS representation.
sequentialisedUnbalancedStm :: Stm -> DistribM (Maybe (Stms SOACS))
sequentialisedUnbalancedStm (Let pat _ (Op soac@(Screma _ _ form)))
  | Just (_, lam2) <- isRedomapSOAC form,
    unbalancedLambda lam2,
    lambdaContainsParallelism lam2 = do
    types <- asksScope scopeForSOACs
    Just . snd <$> runBinderT (FOT.transformSOAC pat soac) types
sequentialisedUnbalancedStm _ =
  return Nothing

-- | Construct a threshold comparison.
--
-- A fresh size key is formed from the description followed by an
-- underscore and the current value of the threshold counter, which
-- is then incremented.  The given sub-expressions are multiplied
-- together (as 64-bit integers, starting from the constant 1) and
-- the product is passed to a 'CmpSizeLe' size operation with the
-- given size class.
--
-- Returns the sub-expression holding the comparison result together
-- with the size key, and the statements that compute the result.
cmpSizeLe ::
  String ->
  Out.SizeClass ->
  [SubExp] ->
  DistribM ((SubExp, Name), Out.Stms Out.GPU)
cmpSizeLe desc size_class to_what = do
  x <- gets stateThresholdCounter
  modify $ \s -> s {stateThresholdCounter = x + 1}
  let size_key = nameFromString $ desc ++ "_" ++ show x
  runBinder $ do
    to_what' <-
      letSubExp "comparatee"
        =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) to_what
    cmp_res <- letSubExp desc $ Op $ SizeOp $ CmpSizeLe size_key size_class to_what'
    return (cmp_res, size_key)

-- | Produce statements that bind the given pattern to the result of
-- one of several alternative bodies.
--
-- Each alternative is paired with a condition; the alternatives are
-- turned into a chain of nested @If@ expressions, tried in list
-- order, with the default body used when there are no alternatives
-- left.
--
-- With no alternatives, the default body is inlined and each pattern
-- name is bound to the corresponding body result.  Otherwise, the
-- remaining alternatives are generated recursively against a copy of
-- the pattern with fresh names (same base names), and the result
-- forms the false branch of an @If@ on the first condition.
kernelAlternatives ::
  (MonadFreshNames m, HasScope Out.GPU m) =>
  Out.Pattern Out.GPU ->
  Out.Body Out.GPU ->
  [(SubExp, Out.Body Out.GPU)] ->
  m (Out.Stms Out.GPU)
kernelAlternatives pat default_body [] = runBinder_ $ do
  ses <- bodyBind default_body
  forM_ (zip (patternNames pat) ses) $ \(name, se) ->
    letBindNames [name] $ BasicOp $ SubExp se
kernelAlternatives pat default_body ((cond, alt) : alts) = runBinder_ $ do
  -- The nested alternatives need their own names, as 'pat' itself is
  -- bound by the enclosing If.
  alts_pat <- fmap (Pattern []) $
    forM (patternElements pat) $ \pe -> do
      name <- newVName $ baseString $ patElemName pe
      return pe {patElemName = name}

  alt_stms <- kernelAlternatives alts_pat default_body alts
  let alt_body = mkBody alt_stms $ map Var $ patternValueNames alts_pat

  letBind pat $
    If cond alt alt_body $
      IfDec (staticShapes (patternTypes pat)) IfEquiv

-- | Transform a lambda by transforming its body (via
-- 'transformBody') with the lambda parameters in scope.  The
-- parameters and the return type are kept unchanged.
transformLambda :: KernelPath -> Lambda -> DistribM (Out.Lambda Out.GPU)
transformLambda path (Lambda params body ret) =
  Lambda params
    <$> localScope (scopeOfLParams params) (transformBody path body)
    <*> pure ret

-- | Transform a single SOACS statement into a sequence of GPU
-- statements.  The equations are tried in order, so the attribute
-- checks come before the structural cases.
transformStm :: KernelPath -> Stm -> DistribM GPUStms
-- A statement carrying the "sequential" attribute is handed over
-- wholesale to 'FOT.transformStmRecursively'; the kernel path is not
-- consulted.
transformStm _ stm
  | "sequential" `inAttrs` stmAuxAttrs (stmAux stm) =
    runBinder_ $ FOT.transformStmRecursively stm
-- A SOAC carrying the "sequential_outer" attribute: the SOAC itself is
-- rewritten with 'FOT.transformSOAC', the certificates of the original
-- statement are attached to every statement that produces, and the
-- result is then fed back through 'transformStms'.
transformStm path (Let pat aux (Op soac))
  | "sequential_outer" `inAttrs` stmAuxAttrs aux =
    transformStms path . stmsToList . fmap (certify (stmAuxCerts aux))
      =<< runBinder_ (FOT.transformSOAC pat soac)
-- Conditional: both branches are transformed independently under the
-- same kernel path, and the 'If' is rebuilt around them with the
-- original pattern, auxiliary information, condition and branch type.
transformStm path (Let pat aux (If c tb fb rt)) = do
  tb' <- transformBody path tb
  fb' <- transformBody path fb
  return $ oneStm $ Let pat aux $ If c tb' fb' rt
-- Accumulator scope: the main lambda goes through 'transformLambda',
-- while the lambda (if any) attached to each input is converted with
-- 'soacsLambdaToGPU'.  Shapes, array names and the remaining component
-- of each input are left untouched.
transformStm path (Let pat aux (WithAcc inputs lam)) =
  oneStm . Let pat aux
    <$> (WithAcc (map transformInput inputs) <$> transformLambda path lam)
  where
    -- Convert only the lambda found in the optional third component.
    transformInput (shape, arrs, op) =
      (shape, arrs, fmap (first soacsLambdaToGPU) op)
-- Sequential loop: the loop itself is kept, and only its body is
-- transformed.  While doing so, the names bound by the loop form and
-- by the merge parameters (context and value) are in scope.
transformStm path (Let pat aux (DoLoop ctx val form body)) =
  localScope
    ( castScope (scopeOf form)
        <> scopeOfFParams mergeparams
    )
    $ oneStm . Let pat aux . DoLoop ctx val form' <$> transformBody path body
  where
    mergeparams = map fst $ ctx ++ val
    -- Rebuilt constructor by constructor so that the loop form is typed
    -- for the output representation; the contents are unchanged.
    form' = case form of
      WhileLoop cond ->
        WhileLoop cond
      ForLoop i it bound ps ->
        ForLoop i it bound ps
-- A Screma that 'isMapSOAC' recognises as a plain map is packaged as a
-- 'MapLoop' and delegated to 'onMap'.
transformStm path (Let pat aux (Op (Screma w arrs form)))
  | Just lam <- isMapSOAC form =
    onMap path $ MapLoop pat aux w lam arrs
-- Scans.  Two alternatives, tried in order:
--
--  * If the form is a pure scan and 'iswim' offers a rewrite of it, run
--    that rewrite (under the original certificates) and transform the
--    statements it produces.
--
--  * Otherwise, if the form is a scan fused with a map, build one
--    'SegBinOp' per scan operator and emit a 'segScan' over the input
--    arrays, certifying every resulting statement.
transformStm path (Let res_pat (StmAux cs _ _) (Op (Screma w arrs form)))
  | Just scans <- isScanSOAC form,
    Scan scan_lam nes <- singleScan scans,
    Just do_iswim <- iswim res_pat w scan_lam $ zip nes arrs = do
    types <- asksScope scopeForSOACs
    transformStms path . stmsToList . snd =<< runBinderT (certifying cs do_iswim) types
  | Just (scans, map_lam) <- isScanomapSOAC form = runBinder_ $ do
    scan_ops <- forM scans $ \(Scan scan_lam nes) -> do
      (scan_lam', nes', shape) <- determineReduceOp scan_lam nes
      let scan_lam'' = soacsLambdaToGPU scan_lam'
      -- Scan operators are always recorded as noncommutative here.
      return $ SegBinOp Noncommutative scan_lam'' nes' shape
    let map_lam_sequential = soacsLambdaToGPU map_lam
    lvl <- segThreadCapped [w] "segscan" $ NoRecommendation SegNoVirt
    addStms . fmap (certify cs)
      =<< segScan lvl res_pat w scan_ops map_lam_sequential arrs [] []
-- A pure reduction with exactly one operator for which 'irwim' offers
-- a rewrite.  The rewrite is run with the auxiliary information of the
-- original statement, its output is simplified, and the simplified
-- statements are transformed in turn.
transformStm path (Let res_pat aux (Op (Screma w arrs form)))
  | Just [Reduce comm red_fun nes] <- isReduceSOAC form,
    -- Treat the operator as commutative when 'commutativeLambda' says
    -- so, regardless of how it was declared.
    let comm'
          | commutativeLambda red_fun = Commutative
          | otherwise = comm,
    Just do_irwim <- irwim res_pat w comm' red_fun $ zip nes arrs = do
    types <- asksScope scopeForSOACs
    (_, bnds) <- fst <$> runBinderT (simplifyStms =<< collectStms_ (auxing aux do_irwim)) types
    transformStms path $ stmsToList bnds
-- A reduction fused with a map (redomap).  Two versions can be
-- generated:
--
--  * "outer": a single 'nonSegRed' whose map function is converted with
--    'soacsLambdaToGPU';
--
--  * "inner": the redomap is split into a map and a reduce with
--    'redomapToMapAndReduce', simplified, and transformed further.
--
-- If the map function contains no parallelism, or the statement has
-- the "sequential_inner" attribute, only the outer version is emitted.
-- Otherwise both are generated and combined by 'kernelAlternatives',
-- with the outer version guarded by the "suff_outer_redomap" check.
transformStm path (Let pat aux@(StmAux cs _ _) (Op (Screma w arrs form)))
  | Just (reds, map_lam) <- isRedomapSOAC form = do
    let paralleliseOuter = runBinder_ $ do
          red_ops <- forM reds $ \(Reduce comm red_lam nes) -> do
            (red_lam', nes', shape) <- determineReduceOp red_lam nes
            let comm'
                  | commutativeLambda red_lam' = Commutative
                  | otherwise = comm
                red_lam'' = soacsLambdaToGPU red_lam'
            return $ SegBinOp comm' red_lam'' nes' shape
          let map_lam_sequential = soacsLambdaToGPU map_lam
          lvl <- segThreadCapped [w] "segred" $ NoRecommendation SegNoVirt
          addStms . fmap (certify cs)
            =<< nonSegRed lvl pat w red_ops map_lam_sequential arrs

        -- The outer version wrapped as a (renamed) body returning the
        -- names bound by the pattern.
        outerParallelBody =
          renameBody
            =<< (mkBody <$> paralleliseOuter <*> pure (map Var (patternNames pat)))

        paralleliseInner path' = do
          (mapstm, redstm) <-
            redomapToMapAndReduce pat (w, comm', red_lam, map_lam, nes, arrs)
          types <- asksScope scopeForSOACs
          transformStms path' . stmsToList <=< (`runBinderT_` types) $ do
            (_, stms) <-
              simplifyStms (stmsFromList [certify cs mapstm, certify cs redstm])
            addStms stms
          where
            comm'
              | commutativeLambda red_lam = Commutative
              | otherwise = comm
            -- The reduction operators are first combined into a single
            -- one by 'singleReduce'.
            (Reduce comm red_lam nes) = singleReduce reds

        -- The inner version wrapped the same way as the outer one.
        innerParallelBody path' =
          renameBody
            =<< (mkBody <$> paralleliseInner path' <*> pure (map Var (patternNames pat)))

    if not (lambdaContainsParallelism map_lam)
      || "sequential_inner" `inAttrs` stmAuxAttrs aux
      then paralleliseOuter
      else do
        ((outer_suff, outer_suff_key), suff_stms) <-
          sufficientParallelism "suff_outer_redomap" [w] path Nothing

        outer_stms <- outerParallelBody
        -- The inner version is generated under a path recording that
        -- the outer check came out False.
        inner_stms <- innerParallelBody ((outer_suff_key, False) : path)

        (suff_stms <>) <$> kernelAlternatives pat inner_stms [(outer_suff, outer_stms)]

-- Streams can be handled in two different ways - either we
-- sequentialise the body or we keep it parallel and distribute.
--
-- This equation covers parallel streams with no accumulators (the
-- neutral-element list is empty) that are not marked
-- "sequential_inner".
transformStm path (Let pat aux@(StmAux cs _ _) (Op (Stream w arrs Parallel {} [] map_fun)))
  | not ("sequential_inner" `inAttrs` stmAuxAttrs aux) = do
    -- No reduction part.  Remove the stream and leave the body
    -- parallel.  It will be distributed.
    types <- asksScope scopeForSOACs
    transformStms path . stmsToList . snd
      =<< runBinderT (certifying cs $ sequentialStreamWholeArray pat w [] map_fun arrs) types
-- Parallel stream with a reduction part.  If the statement is tagged
-- "sequential_inner" only the outer-parallel version is generated.
-- Otherwise both an outer-parallel and an inner-parallel version are
-- built and combined with 'kernelAlternatives', guarded by the
-- "suff_outer_stream" threshold comparison on the stream width @w@.
transformStm path (Let pat aux@(StmAux cs _ _) (Op (Stream w arrs (Parallel o comm red_fun) nes fold_fun)))
  | "sequential_inner" `inAttrs` stmAuxAttrs aux =
    paralleliseOuter path
  | otherwise = do
    ((outer_suff, outer_suff_key), suff_stms) <-
      sufficientParallelism "suff_outer_stream" [w] path Nothing

    -- Each version records in its kernel path which way the
    -- threshold comparison is assumed to have gone.
    outer_stms <- outerParallelBody ((outer_suff_key, True) : path)
    inner_stms <- innerParallelBody ((outer_suff_key, False) : path)

    -- The statements computing the threshold comparison must come
    -- before the alternatives that branch on it.
    (suff_stms <>)
      <$> kernelAlternatives pat inner_stms [(outer_suff, outer_stms)]
  where
    -- Exploit only the parallelism of the stream itself.
    paralleliseOuter path'
      | not $ all primType $ lambdaReturnType red_fun = do
        -- Split into a chunked map and a reduction, with the latter
        -- further transformed.
        let fold_fun' = soacsLambdaToGPU fold_fun

        -- The first @length nes@ pattern elements hold the reduction
        -- results; the remaining ones are passed to 'streamMap'.
        let (red_pat_elems, concat_pat_elems) =
              splitAt (length nes) $ patternValueElements pat
            red_pat = Pattern [] red_pat_elems

        ((num_threads, red_results), stms) <-
          streamMap
            segThreadCapped
            (map (baseString . patElemName) red_pat_elems)
            concat_pat_elems
            w
            Noncommutative
            fold_fun'
            nes
            arrs

        reduce_soac <- reduceSOAC [Reduce comm' red_fun nes]

        -- The final reduction over the 'streamMap' results is itself
        -- a SOACS statement, so it is transformed recursively (with
        -- the attributes cleared, but the rest of @aux@ retained) in
        -- the scope of the statements produced so far.
        (stms <>)
          <$> inScopeOf
            stms
            ( transformStm path' $
                Let red_pat aux {stmAuxAttrs = mempty} $
                  Op (Screma num_threads red_results reduce_soac)
            )
      | otherwise = do
        -- All reduction results are primitive: hand the whole stream
        -- to 'streamRed' and attach the original certificates to
        -- every statement it produces.
        let red_fun_sequential = soacsLambdaToGPU red_fun
            fold_fun_sequential = soacsLambdaToGPU fold_fun
        fmap (certify cs)
          <$> streamRed
            segThreadCapped
            pat
            w
            comm'
            red_fun_sequential
            fold_fun_sequential
            nes
            arrs

    -- The outer-parallel version wrapped as a (renamed) body whose
    -- result is the names bound by the original pattern.
    outerParallelBody path' =
      renameBody
        =<< (mkBody <$> paralleliseOuter path' <*> pure (map Var (patternNames pat)))

    -- Remove the stream via 'sequentialStreamWholeArray', certify the
    -- resulting statements with the original certificates, and
    -- transform them again so that parallelism in the body can be
    -- exploited.
    paralleliseInner path' = do
      types <- asksScope scopeForSOACs
      transformStms path' . fmap (certify cs) . stmsToList . snd
        =<< runBinderT (sequentialStreamWholeArray pat w nes fold_fun arrs) types

    -- The inner-parallel version wrapped as a (renamed) body whose
    -- result is the names bound by the original pattern.
    innerParallelBody path' =
      renameBody
        =<< (mkBody <$> paralleliseInner path' <*> pure (map Var (patternNames pat)))

    -- Use 'Commutative' when 'commutativeLambda' accepts the
    -- reduction operator and the stream is not 'InOrder'; otherwise
    -- fall back to the commutativity recorded in the stream form.
    comm'
      | commutativeLambda red_fun, o /= InOrder = Commutative
      | otherwise = comm
-- Any Screma not handled by an earlier clause: dissect it into
-- simpler statements, re-attach the original certificates to each of
-- them, and transform the result again.
transformStm path (Let pat (StmAux cs _ _) (Op (Screma w arrs form))) = do
  -- This screma is too complicated for us to immediately do
  -- anything, so split it up and try again.
  scope <- asksScope scopeForSOACs
  transformStms path . map (certify cs) . stmsToList . snd
    =<< runBinderT (dissectScrema pat w form arrs) scope
-- Sequential stream: replace it by the statements produced by
-- 'sequentialStreamWholeArray' and transform those again.
transformStm path (Let pat (StmAux cs _ _) (Op (Stream w arrs Sequential nes fold_fun))) = do
  -- Remove the stream and leave the body parallel.  It will be
  -- distributed.
  types <- asksScope scopeForSOACs
  -- The certificates of the stream statement are re-attached to the
  -- replacement statements (as the other clauses that expand a
  -- statement do), so that they are not lost when the stream is
  -- removed.
  transformStms path . map (certify cs) . stmsToList . snd
    =<< runBinderT (sequentialStreamWholeArray pat w nes fold_fun arrs) types
-- Scatter: build a single map kernel over the scatter width @w@
-- whose body is the body of the scatter lambda and whose results are
-- 'WriteReturns' into the destination arrays.
transformStm _ (Let pat (StmAux cs _ _) (Op (Scatter w lam ivs as))) = runBinder_ $ do
  let lam' = soacsLambdaToGPU lam
  -- Fresh name for the index of the kernel's single dimension.
  write_i <- newVName "write_i"
  let (as_ws, _, _) = unzip3 as
      kstms = bodyStms $ lambdaBody lam'
      -- One 'WriteReturns' per destination array; every index
      -- component becomes a fixed dimension index of the write slice.
      krets = do
        (a_w, a, is_vs) <-
          groupScatterResults as $ bodyResult $ lambdaBody lam'
        return $ WriteReturns a_w a [(map DimFix is, v) | (is, v) <- is_vs]
      body = KernelBody () kstms krets
      -- Each lambda parameter reads its input array at @write_i@.
      inputs = do
        (p, p_a) <- zip (lambdaParams lam') ivs
        return $ KernelInput (paramName p) (paramType p) p_a [Var write_i]
  (kernel, stms) <-
    mapKernel
      segThreadCapped
      [(write_i, w)]
      inputs
      -- The kernel return types are the pattern types with as many
      -- outer dimensions stripped as the rank of the corresponding
      -- destination shape.
      (zipWith (stripArray . length) as_ws $ patternTypes pat)
      body
  -- Both the auxiliary statements and the kernel itself carry the
  -- certificates of the original scatter.
  certifying cs $ do
    addStms stms
    letBind pat $ Op $ SegOp kernel
-- Hist: turn the histogram directly into a kernel via 'histKernel',
-- with no enclosing kernel dimensions and no extra kernel inputs.
transformStm _ (Let orig_pat (StmAux cs _ _) (Op (Hist w ops bucket_fun imgs))) = do
  let bfun' = soacsLambdaToGPU bucket_fun

  -- It is important not to launch unnecessarily many threads for
  -- histograms, because it may mean we unnecessarily need to reduce
  -- subhistograms as well.
  runBinder_ $ do
    lvl <- segThreadCapped [w] "seghist" $ NoRecommendation SegNoVirt
    addStms =<< histKernel onLambda lvl orig_pat [] [] cs w ops bfun' imgs
  where
    -- How 'histKernel' converts the SOACS lambdas it is handed (a
    -- pure conversion lifted into the binder monad).
    onLambda = pure . soacsLambdaToGPU
-- Fallback for every statement not matched by an earlier clause:
-- delegate to 'FOT.transformStmRecursively' and collect the
-- statements it emits.  The kernel path is not used here.
transformStm _ bnd =
  runBinder_ $ FOT.transformStmRecursively bnd

-- | Produce a comparison of the sizes @ws@ against a
-- 'Out.SizeThreshold' built from the kernel path and the optional
-- default value, by delegating to 'cmpSizeLe'.  The result is the
-- pair returned by 'cmpSizeLe' (a 'SubExp' and a 'Name') together
-- with the statements that compute it; @desc@ is passed on as the
-- descriptive name.
sufficientParallelism ::
  String ->
  [SubExp] ->
  KernelPath ->
  Maybe Int64 ->
  DistribM ((SubExp, Name), Out.Stms Out.GPU)
sufficientParallelism desc ws path def =
  cmpSizeLe desc (Out.SizeThreshold path def) ws

-- | Intra-group parallelism is worthwhile if the lambda contains more
-- than one instance of non-map nested parallelism, or any nested
-- parallelism inside a loop.
--
-- Concretely, every statement of the lambda body is given an integer
-- "interest" score; the lambda is considered worthwhile when the sum
-- of the scores is strictly greater than 1.
worthIntraGroup :: Lambda -> Bool
worthIntraGroup lam = bodyInterest (lambdaBody lam) > 1
  where
    -- The interest of a body is the sum over its statements.
    bodyInterest body =
      sum $ interest <$> bodyStms body

    -- The guards are tried in order, so the map-specific Screma case
    -- takes precedence over the general Screma case further down.
    interest stm
      | "sequential" `inAttrs` attrs =
        0 :: Int
      | Op (Screma w _ form) <- stmExp stm,
        Just lam' <- isMapSOAC form =
        mapLike w lam'
      | Op (Scatter w lam' _ _) <- stmExp stm =
        mapLike w lam'
      | DoLoop _ _ _ body <- stmExp stm =
        -- Anything interesting inside a loop counts ten times over.
        bodyInterest body * 10
      | If _ tbody fbody _ <- stmExp stm =
        -- Only one branch is executed, so take the larger score.
        max (bodyInterest tbody) (bodyInterest fbody)
      | Op (Screma w _ (ScremaForm _ _ lam')) <- stmExp stm =
        zeroIfTooSmall w + bodyInterest (lambdaBody lam')
      | Op (Stream _ _ Sequential _ lam') <- stmExp stm =
        bodyInterest $ lambdaBody lam'
      | otherwise =
        0
      where
        attrs = stmAuxAttrs $ stmAux stm
        sequential_inner = "sequential_inner" `inAttrs` attrs

        -- A width that is a known integer constant below 32 scores
        -- zero; any other width scores one.
        zeroIfTooSmall (Constant (IntValue x))
          | intToInt64 x < 32 = 0
        zeroIfTooSmall _ = 1

        -- Maps (and scatters) contribute nothing when tagged
        -- "sequential_inner"; otherwise they contribute the larger of
        -- their own width score and the interest of their body.
        mapLike w lam' =
          if sequential_inner
            then 0
            else max (zeroIfTooSmall w) (bodyInterest (lambdaBody lam'))

-- | A lambda is worth sequentialising if it contains enough nested
-- parallelism of an interesting kind.
--
-- Every statement of the lambda body is given an integer "interest"
-- score; the lambda is considered worth sequentialising when the sum
-- of these scores is strictly greater than 1.
worthSequentialising :: Lambda -> Bool
worthSequentialising lam = bodyInterest (lambdaBody lam) > 1
  where
    -- Total interest of a body: the sum over all of its statements.
    bodyInterest body =
      sum $ interest <$> bodyStms body

    -- Interest of a single statement.  The guards are tried in order,
    -- so the map-like 'Screma' case takes precedence over the general
    -- 'Screma' case further down.
    interest stm
      -- Statements explicitly tagged "sequential" never count.
      | "sequential" `inAttrs` attrs =
        0 :: Int
      -- A pure map contributes only what its body contributes, and
      -- nothing at all if tagged "sequential_inner".
      | Op (Screma _ _ form@(ScremaForm _ _ lam')) <- stmExp stm,
        isJust $ isMapSOAC form =
        if sequential_inner
          then 0
          else bodyInterest (lambdaBody lam')
      | Op Scatter {} <- stmExp stm =
        0 -- Basically a map.
      -- The interest of a for-loop body is weighted by a factor of 10.
      | DoLoop _ _ ForLoop {} body <- stmExp stm =
        bodyInterest body * 10
      | WithAcc _ withacc_lam <- stmExp stm =
        bodyInterest (lambdaBody withacc_lam)
      -- Any other 'Screma' (i.e. not a pure map).
      | Op (Screma _ _ form@(ScremaForm _ _ lam')) <- stmExp stm =
        1 + bodyInterest (lambdaBody lam')
          +
          -- Give this a bigger score if it's a redomap, as these
          -- are often tileable and thus benefit more from
          -- sequentialisation.
          case isRedomapSOAC form of
            Just _ -> 1
            Nothing -> 0
      | otherwise =
        0
      where
        attrs = stmAuxAttrs $ stmAux stm
        sequential_inner = "sequential_inner" `inAttrs` attrs

-- | Handle statements that the distribution machinery reports as
-- top-level: convert them with 'transformStms' (under the given kernel
-- path) and lift the resulting 'DistribM' action into 'DistNestT'.
onTopLevelStms ::
  KernelPath ->
  Stms SOACS ->
  DistNestT Out.GPU DistribM GPUStms
onTopLevelStms path stms =
  liftInner $ transformStms path $ stmsToList stms

-- | Transform an outermost map.
--
-- Builds the loop nesting and distribution environment for the map,
-- defines two ways of producing code for it, and hands both to
-- 'onMap'' which decides which versions to generate:
--
-- * @exploitOuterParallelism@ converts the lambda body to the GPU
--   representation as-is, adds its statements to the accumulator and
--   distributes them.  It is passed to 'onMap'' as the @mk_seq_stms@
--   argument.
--
-- * @exploitInnerParallelism@ runs 'distributeMapBodyStms' over the
--   original SOACS statements of the lambda body.  It is passed to
--   'onMap'' as the @mk_par_stms@ argument.
onMap :: KernelPath -> MapLoop -> DistribM GPUStms
onMap path (MapLoop pat aux w lam arrs) = do
  types <- askScope
  let -- A single map nesting pairing each lambda parameter with the
      -- array it is drawn from.
      loopnest = MapNesting pat aux w $ zip (lambdaParams lam) arrs

      -- The distribution environment, parameterised over the kernel
      -- path so that nested maps and top-level statements are handled
      -- under the path of the version currently being generated.
      env path' =
        DistEnv
          { distNest = singleNesting (Nesting mempty loopnest),
            distScope =
              scopeOfPattern pat
                <> scopeForGPU (scopeOf lam)
                <> types,
            distOnInnerMap = onInnerMap path',
            distOnTopLevelStms = onTopLevelStms path',
            distSegLevel = segThreadCapped,
            distOnSOACSStms = pure . oneStm . soacsStmToGPU,
            distOnSOACSLambda = pure . soacsLambdaToGPU
          }

      exploitInnerParallelism path' =
        runDistNestT (env path') $
          distributeMapBodyStms acc (bodyStms $ lambdaBody lam)

      exploitOuterParallelism path' = do
        let lam' = soacsLambdaToGPU lam
        runDistNestT (env path') $
          distribute $
            addStmsToAcc (bodyStms $ lambdaBody lam') acc

  onMap'
    (newKernel loopnest)
    path
    exploitOuterParallelism
    exploitInnerParallelism
    pat
    lam
  where
    -- Initial accumulator: a single target (the map's pattern together
    -- with the result of the lambda body) and no statements yet.
    acc =
      DistAcc
        { distTargets = singleTarget (pat, bodyResult $ lambdaBody lam),
          distStms = mempty
        }

-- | Does the attribute set contain the compound attribute
-- @incremental_flattening@ with the argument @only_intra@?
onlyExploitIntra :: Attrs -> Bool
onlyExploitIntra attrs =
  AttrComp "incremental_flattening" ["only_intra"] `inAttrs` attrs

-- | False exactly when the attribute set contains the compound
-- attribute @incremental_flattening@ with the argument @no_outer@ or
-- with the argument @only_inner@; true otherwise.
mayExploitOuter :: Attrs -> Bool
mayExploitOuter attrs =
  not $
    AttrComp "incremental_flattening" ["no_outer"] `inAttrs` attrs
      || AttrComp "incremental_flattening" ["only_inner"] `inAttrs` attrs

-- | False exactly when the attribute set contains the compound
-- attribute @incremental_flattening@ with the argument @no_intra@ or
-- with the argument @only_inner@; true otherwise.
mayExploitIntra :: Attrs -> Bool
mayExploitIntra attrs =
  not $
    AttrComp "incremental_flattening" ["no_intra"] `inAttrs` attrs
      || AttrComp "incremental_flattening" ["only_inner"] `inAttrs` attrs

-- The minimum amount of inner parallelism we require (by default) in
-- intra-group versions.  Less than this is usually pointless on a GPU
-- (but we allow tuning to change it).
intraMinInnerPar :: Int64
intraMinInnerPar = 32 -- One NVIDIA warp

-- | Generate the (possibly multi-versioned) code for a map nest.
--
-- Arguments, in order: the kernel nest; the kernel path so far; an
-- action producing the statements of the version built by
-- @mk_seq_stms@; an action producing the statements of the version
-- built by @mk_par_stms@; the pattern bound by the map; and the map
-- lambda.  Both actions take the kernel path under which their version
-- is generated.
--
-- Depending on the attributes of the innermost nesting and on whether
-- intra-group parallelisation is attempted and succeeds, up to three
-- versions are combined with 'kernelAlternatives': the @mk_seq_stms@
-- version (guarded by the @suff_outer_par@ check), the intra-group
-- version (guarded by the @suff_intra_par@ check and a group size
-- check), and the @mk_par_stms@ version as the default.
onMap' ::
  KernelNest ->
  KernelPath ->
  (KernelPath -> DistribM (Out.Stms Out.GPU)) ->
  (KernelPath -> DistribM (Out.Stms Out.GPU)) ->
  Pattern ->
  Lambda ->
  DistribM (Out.Stms Out.GPU)
onMap' loopnest path mk_seq_stms mk_par_stms pat lam = do
  -- Some of the control flow here looks a bit convoluted because we
  -- are trying to avoid generating unneeded threshold parameters,
  -- which means we need to do all the pruning checks up front.

  types <- askScope

  -- Only attempt intra-group parallelisation if it is explicitly
  -- requested, or if it looks worthwhile and is not forbidden.
  intra <-
    if onlyExploitIntra attrs
      || (worthIntraGroup lam && mayExploitIntra attrs)
      then runReaderT (intraGroupParallelise loopnest lam) types
      else pure Nothing

  case intra of
    -- "sequential_inner": only the mk_seq_stms version, unconditionally.
    _ | "sequential_inner" `inAttrs` attrs -> do
      seq_body <- renamedBody =<< mk_seq_stms path
      kernelAlternatives pat seq_body []
    --
    -- No intra-group version available.
    Nothing
      | Just m <- mkSeqAlts -> do
          (outer_suff, outer_suff_key, outer_suff_stms, seq_body) <- m
          par_body <-
            renamedBody =<< mk_par_stms ((outer_suff_key, False) : path)
          (outer_suff_stms <>)
            <$> kernelAlternatives pat par_body [(outer_suff, seq_body)]
      --
      | otherwise -> do
          par_body <- renamedBody =<< mk_par_stms path
          kernelAlternatives pat par_body []
    --
    -- An intra-group version is available.
    Just intra'@(_, _, intra_log, intra_prelude, intra_stms)
      | onlyExploitIntra attrs -> do
          addLog intra_log
          group_par_body <- renamedBody intra_stms
          (intra_prelude <>) <$> kernelAlternatives pat group_par_body []
      --
      | otherwise -> do
          addLog intra_log

          case mkSeqAlts of
            Nothing -> do
              (group_par_body, intra_ok, intra_suff_key, intra_suff_stms) <-
                checkSuffIntraPar path intra'

              par_body <-
                renamedBody =<< mk_par_stms ((intra_suff_key, False) : path)

              (intra_suff_stms <>)
                <$> kernelAlternatives pat par_body [(intra_ok, group_par_body)]
            Just m -> do
              (outer_suff, outer_suff_key, outer_suff_stms, seq_body) <- m

              (group_par_body, intra_ok, intra_suff_key, intra_suff_stms) <-
                checkSuffIntraPar ((outer_suff_key, False) : path) intra'

              par_body <-
                renamedBody
                  =<< mk_par_stms
                    ( [ (outer_suff_key, False),
                        (intra_suff_key, False)
                      ]
                        ++ path
                    )

              ((outer_suff_stms <> intra_suff_stms) <>)
                <$> kernelAlternatives
                  pat
                  par_body
                  [(outer_suff, seq_body), (intra_ok, group_par_body)]
  where
    nest_ws = kernelNestWidths loopnest
    res = map Var $ patternNames pat
    aux = loopNestingAux $ innermostKernelNesting loopnest
    attrs = stmAuxAttrs aux

    -- Build a body from the given statements whose result is the
    -- variables of the map's pattern, then run it through
    -- 'renameBody'.
    renamedBody stms = renameBody $ mkBody stms res

    -- The mk_seq_stms alternative, if it is both worthwhile and
    -- permitted.  Running the returned action creates the
    -- "suff_outer_par" check; the result is the check's condition,
    -- its key, the statements computing it, and the alternative body
    -- (generated under a path that records the check as true).
    mkSeqAlts
      | worthSequentialising lam,
        mayExploitOuter attrs = Just $ do
          ((outer_suff, outer_suff_key), outer_suff_stms) <- checkSuffOuterPar
          seq_body <-
            renamedBody =<< mk_seq_stms ((outer_suff_key, True) : path)
          pure (outer_suff, outer_suff_key, outer_suff_stms, seq_body)
      | otherwise =
          Nothing

    checkSuffOuterPar =
      sufficientParallelism "suff_outer_par" nest_ws path Nothing

    -- Produce the intra-group body together with the condition under
    -- which it may be used, the key of the "suff_intra_par" check, and
    -- the statements computing the condition.
    checkSuffIntraPar
      path'
      ((_intra_min_par, intra_avail_par), group_size, _, intra_prelude, intra_stms) = do
        -- We must check that all intra-group parallelism fits in a group.
        ((intra_ok, intra_suff_key), intra_suff_stms) <- do
          ((intra_suff, suff_key), check_suff_stms) <-
            sufficientParallelism
              "suff_intra_par"
              [intra_avail_par]
              path'
              (Just intraMinInnerPar)

          runBinder $ do
            addStms intra_prelude

            max_group_size <-
              letSubExp "max_group_size" $
                Op $ SizeOp $ Out.GetSizeMax Out.SizeGroup
            fits <-
              letSubExp "fits" $
                BasicOp $
                  CmpOp (CmpSle Int64) group_size max_group_size

            addStms check_suff_stms

            -- The intra-group version is usable only if the group
            -- size fits AND the parallelism check passes.
            intra_ok <-
              letSubExp "intra_suff_and_fits" $
                BasicOp $ BinOp LogAnd fits intra_suff
            return (intra_ok, suff_key)

        group_par_body <- renamedBody intra_stms
        pure (group_par_body, intra_ok, intra_suff_key, intra_suff_stms)

-- | Handle a map loop encountered inside a nest that is being distributed.
--
-- Given the current 'KernelPath', the 'MapLoop' and the distribution
-- accumulator, this either
--
--   * adds the map loop to the accumulator as an ordinary statement
--     (first guard),
--
--   * distributes it as a single statement and then hands the resulting
--     nest to @multiVersion@, or
--
--   * falls back to 'distributeMap' when neither of the above applies.
--
-- The result is the updated accumulator.
onInnerMap ::
  KernelPath ->
  MapLoop ->
  DistAcc Out.GPU ->
  DistNestT Out.GPU DistribM (DistAcc Out.GPU)
onInnerMap :: KernelPath
-> MapLoop -> DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
onInnerMap KernelPath
path maploop :: MapLoop
maploop@(MapLoop Pattern SOACS
pat StmAux ()
aux SubExp
w Lambda
lam [VName]
arrs) DistAcc GPU
acc
  -- When both 'unbalancedLambda' and 'lambdaContainsParallelism' hold
  -- for the lambda, the map loop is not distributed here: it is turned
  -- back into a statement and added to the accumulator as-is.
  | Lambda -> Bool
unbalancedLambda Lambda
lam,
    Lambda -> Bool
lambdaContainsParallelism Lambda
lam =
    Stm SOACS -> DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc (MapLoop -> Stm SOACS
mapLoopStm MapLoop
maploop) DistAcc GPU
acc
  -- Otherwise, try to distribute the map loop as a single statement.
  -- Only when that succeeds *and* 'permutationAndMissing' can relate
  -- the pattern to the result do we go on to multi-versioning.
  | Bool
otherwise =
    DistAcc GPU
-> Stm SOACS
-> DistNestT
     GPU
     DistribM
     (Maybe (PostStms GPU, Result, KernelNest, DistAcc GPU))
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc GPU
acc (MapLoop -> Stm SOACS
mapLoopStm MapLoop
maploop) DistNestT
  GPU
  DistribM
  (Maybe (PostStms GPU, Result, KernelNest, DistAcc GPU))
-> (Maybe (PostStms GPU, Result, KernelNest, DistAcc GPU)
    -> DistNestT GPU DistribM (DistAcc GPU))
-> DistNestT GPU DistribM (DistAcc GPU)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
      Just (PostStms GPU
post_kernels, Result
res, KernelNest
nest, DistAcc GPU
acc')
        | Just ([Int]
perm, [PatElemT Type]
_pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern SOACS
pat Result
res -> do
          PostStms GPU -> DistNestT GPU DistribM ()
forall (m :: * -> *) rep.
Monad m =>
PostStms rep -> DistNestT rep m ()
addPostStms PostStms GPU
post_kernels
          [Int]
-> KernelNest
-> DistAcc GPU
-> DistNestT GPU DistribM (DistAcc GPU)
multiVersion [Int]
perm KernelNest
nest DistAcc GPU
acc'
      -- Either the single-statement distribution failed, or no
      -- permutation was found: fall back to plain map distribution
      -- using the original accumulator.
      Maybe (PostStms GPU, Result, KernelNest, DistAcc GPU)
_ -> MapLoop -> DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
distributeMap MapLoop
maploop DistAcc GPU
acc
  where
    -- Replace the targets of an accumulator with a single empty target
    -- (empty pattern, empty result); all other fields are kept.
    discardTargets :: DistAcc rep -> DistAcc rep
discardTargets DistAcc rep
acc' =
      -- FIXME: work around bogus targets.
      DistAcc rep
acc' {distTargets :: Targets
distTargets = Target -> Targets
singleTarget (PatternT Type
forall a. Monoid a => a
mempty, Result
forall a. Monoid a => a
mempty)}

    -- Build the statements for this map loop from two candidates -- a
    -- sequentialised kernel and a version produced by distributing the
    -- map body -- by passing both to onMap', then post the resulting
    -- statements and return the given accumulator unchanged.
    multiVersion :: [Int]
-> KernelNest
-> DistAcc GPU
-> DistNestT GPU DistribM (DistAcc GPU)
multiVersion [Int]
perm KernelNest
nest DistAcc GPU
acc' = do
      -- The kernel can be distributed by itself, so now we can
      -- decide whether to just sequentialise, or exploit inner
      -- parallelism.
      DistEnv GPU DistribM
dist_env <- DistNestT GPU DistribM (DistEnv GPU DistribM)
forall r (m :: * -> *). MonadReader r m => m r
ask
      let extra_scope :: Scope GPU
extra_scope = Targets -> Scope GPU
forall rep. DistRep rep => Targets -> Scope rep
targetsScope (Targets -> Scope GPU) -> Targets -> Scope GPU
forall a b. (a -> b) -> a -> b
$ DistAcc GPU -> Targets
forall rep. DistAcc rep -> Targets
distTargets DistAcc GPU
acc'

      Stms GPU
stms <- DistribM (Stms GPU) -> DistNestT GPU DistribM (Stms GPU)
forall rep (m :: * -> *) a.
(LocalScope rep m, DistRep rep) =>
m a -> DistNestT rep m a
liftInner (DistribM (Stms GPU) -> DistNestT GPU DistribM (Stms GPU))
-> DistribM (Stms GPU) -> DistNestT GPU DistribM (Stms GPU)
forall a b. (a -> b) -> a -> b
$
        Scope GPU -> DistribM (Stms GPU) -> DistribM (Stms GPU)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope Scope GPU
extra_scope (DistribM (Stms GPU) -> DistribM (Stms GPU))
-> DistribM (Stms GPU) -> DistribM (Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
          -- Rebuilt from the same components that the outer pattern
          -- match bound, so it has the same contents as 'maploop'.
          let maploop' :: MapLoop
maploop' = Pattern SOACS
-> StmAux () -> SubExp -> Lambda -> [VName] -> MapLoop
MapLoop Pattern SOACS
pat StmAux ()
aux SubExp
w Lambda
lam [VName]
arrs

              -- Candidate that distributes the map body: runs
              -- 'distributeMap' inside the given nest, with the
              -- environment's top-level-statement and inner-map
              -- handlers re-targeted at the supplied path, and with the
              -- targets of the resulting accumulator discarded.
              -- NOTE(review): the accumulator passed to 'distributeMap'
              -- is the outer 'acc' (with 'distStms' emptied), not
              -- 'acc'' -- confirm this is intentional.
              exploitInnerParallelism :: KernelPath -> DistribM (Stms GPU)
exploitInnerParallelism KernelPath
path' = do
                let dist_env' :: DistEnv GPU DistribM
dist_env' =
                      DistEnv GPU DistribM
dist_env
                        { distOnTopLevelStms :: Stms SOACS -> DistNestT GPU DistribM (Stms GPU)
distOnTopLevelStms = KernelPath -> Stms SOACS -> DistNestT GPU DistribM (Stms GPU)
onTopLevelStms KernelPath
path',
                          distOnInnerMap :: MapLoop -> DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
distOnInnerMap = KernelPath
-> MapLoop -> DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
onInnerMap KernelPath
path'
                        }
                DistEnv GPU DistribM
-> DistNestT GPU DistribM (DistAcc GPU) -> DistribM (Stms GPU)
forall (m :: * -> *) rep.
(MonadLogger m, DistRep rep) =>
DistEnv rep m -> DistNestT rep m (DistAcc rep) -> m (Stms rep)
runDistNestT DistEnv GPU DistribM
dist_env' (DistNestT GPU DistribM (DistAcc GPU) -> DistribM (Stms GPU))
-> DistNestT GPU DistribM (DistAcc GPU) -> DistribM (Stms GPU)
forall a b. (a -> b) -> a -> b
$
                  KernelNest
-> DistNestT GPU DistribM (DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU)
forall (m :: * -> *) rep a.
(Monad m, DistRep rep) =>
KernelNest -> DistNestT rep m a -> DistNestT rep m a
inNesting KernelNest
nest (DistNestT GPU DistribM (DistAcc GPU)
 -> DistNestT GPU DistribM (DistAcc GPU))
-> DistNestT GPU DistribM (DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU)
forall a b. (a -> b) -> a -> b
$
                    Scope GPU
-> DistNestT GPU DistribM (DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope Scope GPU
extra_scope (DistNestT GPU DistribM (DistAcc GPU)
 -> DistNestT GPU DistribM (DistAcc GPU))
-> DistNestT GPU DistribM (DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU)
forall a b. (a -> b) -> a -> b
$
                      DistAcc GPU -> DistAcc GPU
forall {rep}. DistAcc rep -> DistAcc rep
discardTargets (DistAcc GPU -> DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> MapLoop -> DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
distributeMap MapLoop
maploop' DistAcc GPU
acc {distStms :: Stms GPU
distStms = Stms GPU
forall a. Monoid a => a
mempty}

          -- Normally the permutation is for the output pattern, but
          -- we can't really change that, so we change the result
          -- order instead.
          let lam_res' :: Result
lam_res' =
                [Int] -> Result -> Result
forall a. [Int] -> [a] -> [a]
rearrangeShape ([Int] -> [Int]
rearrangeInverse [Int]
perm) (Result -> Result) -> Result -> Result
forall a b. (a -> b) -> a -> b
$
                  BodyT SOACS -> Result
forall rep. BodyT rep -> Result
bodyResult (BodyT SOACS -> Result) -> BodyT SOACS -> Result
forall a b. (a -> b) -> a -> b
$ Lambda -> BodyT SOACS
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda
lam
              lam' :: Lambda
lam' = Lambda
lam {lambdaBody :: BodyT SOACS
lambdaBody = (Lambda -> BodyT SOACS
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda
lam) {bodyResult :: Result
bodyResult = Result
lam_res'}}
              map_nesting :: LoopNesting
map_nesting = PatternT Type
-> StmAux () -> SubExp -> [(Param Type, VName)] -> LoopNesting
MapNesting PatternT Type
Pattern SOACS
pat StmAux ()
aux SubExp
w ([(Param Type, VName)] -> LoopNesting)
-> [(Param Type, VName)] -> LoopNesting
forall a b. (a -> b) -> a -> b
$ [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda -> [LParam SOACS]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda
lam) [VName]
arrs
              nest' :: KernelNest
nest' = Target -> LoopNesting -> KernelNest -> KernelNest
pushInnerKernelNesting (PatternT Type
Pattern SOACS
pat, Result
lam_res') LoopNesting
map_nesting KernelNest
nest

          -- XXX: we do not construct a new KernelPath when
          -- sequentialising.  This is only OK as long as further
          -- versioning does not take place down that branch (it currently
          -- does not).
          (Stm GPU
sequentialised_kernel, Stms GPU
nestw_bnds) <- Scope GPU
-> DistribM (Stm GPU, Stms GPU) -> DistribM (Stm GPU, Stms GPU)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope Scope GPU
extra_scope (DistribM (Stm GPU, Stms GPU) -> DistribM (Stm GPU, Stms GPU))
-> DistribM (Stm GPU, Stms GPU) -> DistribM (Stm GPU, Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
            let sequentialised_lam :: Lambda GPU
sequentialised_lam = Lambda -> Lambda GPU
soacsLambdaToGPU Lambda
lam'
            MkSegLevel GPU DistribM
-> KernelNest -> Body GPU -> DistribM (Stm GPU, Stms GPU)
forall rep (m :: * -> *).
(DistRep rep, MonadFreshNames m, LocalScope rep m) =>
MkSegLevel rep m -> KernelNest -> Body rep -> m (Stm rep, Stms rep)
constructKernel MkSegLevel GPU DistribM
forall (m :: * -> *). MonadFreshNames m => MkSegLevel GPU m
segThreadCapped KernelNest
nest' (Body GPU -> DistribM (Stm GPU, Stms GPU))
-> Body GPU -> DistribM (Stm GPU, Stms GPU)
forall a b. (a -> b) -> a -> b
$ Lambda GPU -> Body GPU
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda GPU
sequentialised_lam

          -- Hand both candidates to onMap': the first continuation
          -- ignores its path and yields the already-built sequentialised
          -- kernel, the second is 'exploitInnerParallelism'.  The
          -- statements bound alongside the sequentialised kernel
          -- ('nestw_bnds') are prepended to whatever onMap' returns.
          -- (How onMap' combines the two is defined elsewhere.)
          let outer_pat :: PatternT Type
outer_pat = LoopNesting -> PatternT Type
loopNestingPattern (LoopNesting -> PatternT Type) -> LoopNesting -> PatternT Type
forall a b. (a -> b) -> a -> b
$ KernelNest -> LoopNesting
forall a b. (a, b) -> a
fst KernelNest
nest
          (Stms GPU
nestw_bnds Stms GPU -> Stms GPU -> Stms GPU
forall a. Semigroup a => a -> a -> a
<>)
            (Stms GPU -> Stms GPU)
-> DistribM (Stms GPU) -> DistribM (Stms GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelNest
-> KernelPath
-> (KernelPath -> DistribM (Stms GPU))
-> (KernelPath -> DistribM (Stms GPU))
-> Pattern SOACS
-> Lambda
-> DistribM (Stms GPU)
onMap'
              KernelNest
nest'
              KernelPath
path
              (DistribM (Stms GPU) -> KernelPath -> DistribM (Stms GPU)
forall a b. a -> b -> a
const (DistribM (Stms GPU) -> KernelPath -> DistribM (Stms GPU))
-> DistribM (Stms GPU) -> KernelPath -> DistribM (Stms GPU)
forall a b. (a -> b) -> a -> b
$ Stms GPU -> DistribM (Stms GPU)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms GPU -> DistribM (Stms GPU))
-> Stms GPU -> DistribM (Stms GPU)
forall a b. (a -> b) -> a -> b
$ Stm GPU -> Stms GPU
forall rep. Stm rep -> Stms rep
oneStm Stm GPU
sequentialised_kernel)
              KernelPath -> DistribM (Stms GPU)
exploitInnerParallelism
              PatternT Type
Pattern SOACS
outer_pat
              Lambda
lam'

      -- Post the generated statements and continue with the
      -- accumulator that came out of the single-statement distribution.
      Stms GPU -> DistNestT GPU DistribM ()
forall (m :: * -> *) rep. Monad m => Stms rep -> DistNestT rep m ()
postStm Stms GPU
stms
      DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc GPU
acc'