{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns -Wno-incomplete-patterns -Wno-incomplete-uni-patterns -Wno-incomplete-record-updates #-}

module Futhark.Pass.ExtractKernels.DistributeNests
  ( MapLoop (..),
    mapLoopStm,
    bodyContainsParallelism,
    lambdaContainsParallelism,
    determineReduceOp,
    histKernel,
    DistEnv (..),
    DistAcc (..),
    runDistNestT,
    DistNestT,
    liftInner,
    distributeMap,
    distribute,
    distributeSingleStm,
    distributeMapBodyStms,
    addStmsToAcc,
    addStmToAcc,
    permutationAndMissing,
    addPostStms,
    postStm,
    inNesting,
  )
where

import Control.Arrow (first)
import Control.Monad.Identity
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.Trans.Maybe
import Control.Monad.Writer.Strict
import Data.Function ((&))
import Data.List (find, partition, tails)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map qualified as M
import Data.Maybe
import Futhark.IR
import Futhark.IR.SOACS (SOACS)
import Futhark.IR.SOACS qualified as SOACS
import Futhark.IR.SOACS.SOAC hiding (HistOp, histDest)
import Futhark.IR.SOACS.Simplify (simpleSOACS, simplifyStms)
import Futhark.IR.SegOp
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.Interchange
import Futhark.Tools
import Futhark.Transform.CopyPropagate
import Futhark.Transform.FirstOrderTransform qualified as FOT
import Futhark.Transform.Rename
import Futhark.Util
import Futhark.Util.Log

-- | View a scope of some representation as a 'SOACS' scope.  The
-- 'SameScope' constraint is what makes the 'castScope' below legal.
scopeForSOACs :: SameScope rep SOACS => Scope rep -> Scope SOACS
scopeForSOACs scope = castScope scope

-- | An unpacked @map@: its pattern, auxiliary information, width,
-- lambda and input arrays (in the order 'mapLoopStm' consumes them).
data MapLoop
  = MapLoop (Pat Type) (StmAux ()) SubExp (Lambda SOACS) [VName]

-- | Rebuild the statement described by a 'MapLoop': the pattern and
-- auxiliary information are reused unchanged, and the lambda is mapped
-- over the arrays with the recorded width.
mapLoopStm :: MapLoop -> Stm SOACS
mapLoopStm (MapLoop pat aux w lam arrs) =
  let soac = Screma w arrs (mapSOAC lam)
   in Let pat aux (Op soac)

-- | The read-only environment of 'DistNestT'.
data DistEnv rep m = DistEnv
  { -- | The current map nestings; extended by 'mapNesting',
    -- 'inNesting' and 'withStm'.
    distNest :: Nestings,
    -- | The scope visible at the current point of distribution.
    distScope :: Scope rep,
    -- | Callback for top-level statements (not used in this part of the
    -- module; role presumably handled by callers - TODO confirm).
    distOnTopLevelStms :: Stms SOACS -> DistNestT rep m (Stms rep),
    -- | Callback for an inner 'MapLoop' (not used in this part of the
    -- module - TODO confirm role).
    distOnInnerMap ::
      MapLoop ->
      DistAcc rep ->
      DistNestT rep m (DistAcc rep),
    -- | Converts a single SOACS statement; used by 'addStmToAcc'.
    distOnSOACSStms :: Stm SOACS -> Builder rep (Stms rep),
    -- | Converts a SOACS lambda; used by 'soacsLambda'.
    distOnSOACSLambda :: Lambda SOACS -> Builder rep (Lambda rep),
    -- | Segment-level constructor (not used in this part of the module).
    distSegLevel :: MkSegLevel rep m
  }

-- | The accumulator threaded through distribution.
data DistAcc rep = DistAcc
  { -- | The result targets of the current nest.
    distTargets :: Targets,
    -- | Statements collected so far; new statements are prepended.
    distStms :: Stms rep
  }

-- | What 'DistNestT' writes as it runs.
data DistRes rep = DistRes
  { -- | Statements emitted with 'addPostStms' / 'postStm'.
    accPostStms :: PostStms rep,
    -- | Log messages emitted through 'MonadLogger'.
    accLog :: Log
  }

-- | Both components are combined pointwise with their own '(<>)'.
instance Semigroup (DistRes rep) where
  DistRes post_a log_a <> DistRes post_b log_b =
    DistRes (post_a <> post_b) (log_a <> log_b)

-- | No post-statements and an empty log.
instance Monoid (DistRes rep) where
  mempty = DistRes {accPostStms = mempty, accLog = mempty}

-- | Statements written out by 'DistNestT' via 'addPostStms'.  Note that
-- its 'Semigroup' instance combines in reverse order.
newtype PostStms rep = PostStms {unPostStms :: Stms rep}

-- | Flipped concatenation: in @PostStms a <> PostStms b@, the statements
-- of @b@ are placed before those of @a@.
instance Semigroup (PostStms rep) where
  PostStms lhs <> PostStms rhs = PostStms (rhs <> lhs)

-- | No statements.
instance Monoid (PostStms rep) where
  mempty = PostStms {unPostStms = mempty}

-- | The scope bound by the pattern of the outermost target in the
-- accumulator.
typeEnvFromDistAcc :: DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc acc =
  let (outer_pat, _) = outerTarget (distTargets acc)
   in scopeOfPat outer_pat

-- | Prepend the given statements to those already in the accumulator.
addStmsToAcc :: Stms rep -> DistAcc rep -> DistAcc rep
addStmsToAcc new_stms acc =
  let old_stms = distStms acc
   in acc {distStms = new_stms <> old_stms}

-- | Convert a SOACS statement with the environment's 'distOnSOACSStms'
-- callback and prepend the result to the accumulator.
--
-- NOTE(review): only the statements /returned/ by the callback are kept;
-- anything the callback adds to the builder itself is discarded.
addStmToAcc ::
  (MonadFreshNames m, DistRep rep) =>
  Stm SOACS ->
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
addStmToAcc stm acc = do
  translate <- asks distOnSOACSStms
  (converted, _) <- runBuilder $ translate stm
  pure $ addStmsToAcc converted acc

-- | Convert a SOACS lambda with the environment's 'distOnSOACSLambda'
-- callback, keeping only the returned lambda.
soacsLambda ::
  (MonadFreshNames m, DistRep rep) =>
  Lambda SOACS ->
  DistNestT rep m (Lambda rep)
soacsLambda lam = do
  translate <- asks distOnSOACSLambda
  fst <$> runBuilder (translate lam)

-- | The distribution monad: a reader of 'DistEnv' over a writer of
-- 'DistRes', on top of an underlying monad @m@.
newtype DistNestT rep m a
  = DistNestT (ReaderT (DistEnv rep m) (WriterT (DistRes rep) m) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (DistEnv rep m),
      MonadWriter (DistRes rep)
    )

-- | Run an action of the underlying monad.  Bindings that are in the
-- distribution scope but not in the underlying monad's own scope are
-- made visible to the action with 'localScope'.
liftInner :: (LocalScope rep m, DistRep rep) => m a -> DistNestT rep m a
liftInner m = do
  outer_scope <- askScope
  DistNestT . lift . lift $ do
    inner_scope <- askScope
    let missing = outer_scope `M.difference` inner_scope
    localScope missing m

-- | The name source is the one of the wrapped writer layer.
instance MonadFreshNames m => MonadFreshNames (DistNestT rep m) where
  getNameSource = DistNestT (lift getNameSource)
  putNameSource src = DistNestT (lift (putNameSource src))

-- | The scope is the 'distScope' field of the environment.
instance (Monad m, ASTRep rep) => HasScope rep (DistNestT rep m) where
  askScope = distScope <$> ask

-- | Extends 'distScope'; the new bindings are unioned in on the left, so
-- they shadow existing ones.
instance (Monad m, ASTRep rep) => LocalScope rep (DistNestT rep m) where
  localScope types = local addTypes
    where
      addTypes env = env {distScope = types <> distScope env}

-- | Log messages are written to the 'accLog' component of 'DistRes'.
instance Monad m => MonadLogger (DistNestT rep m) where
  addLog msgs = tell (DistRes {accPostStms = mempty, accLog = msgs})

-- | Run a distribution computation in the given environment.
--
-- The accumulated log is re-emitted in the underlying monad.  The result
-- is the written post-statements, followed by statements that bind what
-- remains of the outermost target:
--
-- * a result that is a parameter of the outermost nesting becomes a
--   'Copy' of the corresponding input array;
-- * any other result becomes a 'Replicate' over the outermost width.
runDistNestT ::
  (MonadLogger m, DistRep rep) =>
  DistEnv rep m ->
  DistNestT rep m (DistAcc rep) ->
  m (Stms rep)
runDistNestT env (DistNestT action) = do
  (acc, res) <- runWriterT (runReaderT action env)
  addLog (accLog res)
  -- There may be a few final targets remaining - these correspond to
  -- arrays that are identity mapped, and must have statements
  -- inserted here.
  let remaining = identityStms $ outerTarget $ distTargets acc
  pure $ unPostStms (accPostStms res) <> remaining
  where
    -- With no enclosing nestings, the single nesting is the outermost
    -- one; otherwise it is the head of the list.
    outermost =
      nestingLoop $
        case distNest env of
          (inner, nests) -> fromMaybe inner (listToMaybe nests)

    params_to_arrs = map (first paramName) $ loopNestingParamsAndArrs outermost

    identityStms (rem_pat, rem_res) =
      stmsFromList $ zipWith identityStm (patElems rem_pat) rem_res

    identityStm pe (SubExpRes cs se) =
      certify cs . Let (Pat [pe]) (defAux ()) . BasicOp $
        case se of
          Var v
            | Just arr <- lookup v params_to_arrs -> Copy arr
          _ -> Replicate (Shape [loopNestingWidth outermost]) se

-- | Write post-statements (with an empty log).
addPostStms :: Monad m => PostStms rep -> DistNestT rep m ()
addPostStms ks = tell (DistRes {accPostStms = ks, accLog = mempty})

-- | Write plain statements as post-statements.
postStm :: Monad m => Stms rep -> DistNestT rep m ()
postStm = addPostStms . PostStms

-- | Run an action after the given statement.  The statement's bindings
-- are added to the scope and let-bound in the innermost nesting via
-- 'letBindInInnerNesting'.
withStm ::
  (Monad m, DistRep rep) =>
  Stm SOACS ->
  DistNestT rep m a ->
  DistNestT rep m a
withStm stm = local bindStm
  where
    provided = namesFromList $ patNames $ stmPat stm
    bindStm env =
      env
        { distScope = castScope (scopeOf stm) <> distScope env,
          distNest = letBindInInnerNesting provided (distNest env)
        }

-- | Leave the innermost nesting: pop the innermost target and turn
-- whatever is left in the accumulator into statements of the enclosing
-- nesting.  Fails with 'error' if there is no target to pop.
--
-- * If statements remain, they are wrapped in a @map@ over those
--   parameters/arrays of the innermost nesting that they use, and the map
--   is sequentialised with 'FOT.transformSOAC'.
--
-- * Otherwise each remaining result becomes a 'Copy' (when it is a
--   parameter of the innermost nesting) or a 'Replicate' over the
--   nesting's width.
leavingNesting ::
  (MonadFreshNames m, DistRep rep) =>
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
leavingNesting acc =
  case popInnerTarget $ distTargets acc of
    Nothing ->
      error "The kernel targets list is unexpectedly small"
    Just ((pat, res), newtargets) -> do
      -- Both cases need the innermost nesting, so look it up once.
      (Nesting _ inner, _) <- asks distNest
      let w = loopNestingWidth inner
          aux = loopNestingAux inner
          params_and_arrs = loopNestingParamsAndArrs inner
      stms <-
        if not $ null $ distStms acc
          then do
            -- Any statements left over correspond to something that
            -- could not be distributed because it would cause irregular
            -- arrays.  These must be reconstructed into a Map SOAC that
            -- will be sequentialised. XXX: life would be better if we
            -- were able to distribute irregular parallelism.
            let body = Body () (distStms acc) res
                used_in_body = freeIn body
                (used_params, used_arrs) =
                  unzip $
                    filter ((`nameIn` used_in_body) . paramName . fst) params_and_arrs
                lam' =
                  Lambda
                    { lambdaParams = used_params,
                      lambdaBody = body,
                      lambdaReturnType = map rowType $ patTypes pat
                    }
            runBuilder_ . auxing aux . FOT.transformSOAC pat $
              Screma w used_arrs $
                mapSOAC lam'
          else do
            -- Any results left over correspond to a Replicate or a Copy in
            -- the parent nesting, depending on whether the argument is a
            -- parameter of the innermost nesting.
            let remnantStm pe (SubExpRes cs (Var v))
                  | Just (_, arr) <- find ((== v) . paramName . fst) params_and_arrs =
                      certify cs $ Let (Pat [pe]) aux $ BasicOp $ Copy arr
                remnantStm pe (SubExpRes cs se) =
                  certify cs $ Let (Pat [pe]) aux $ BasicOp $ Replicate (Shape [w]) se
            pure $ stmsFromList $ zipWith remnantStm (patElems pat) res
      pure $ acc {distTargets = newtargets, distStms = stms}

-- | Push a new innermost nesting for the map described by the arguments
-- (pattern, auxiliary information, width, lambda, input arrays), with the
-- lambda's bindings in scope.  The action runs inside that nesting and
-- its accumulator is then passed through 'leavingNesting', which also
-- runs inside the nesting.
mapNesting ::
  (MonadFreshNames m, DistRep rep) =>
  Pat Type ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  DistNestT rep m (DistAcc rep) ->
  DistNestT rep m (DistAcc rep)
mapNesting pat aux w lam arrs m = local enter (m >>= leavingNesting)
  where
    nest =
      Nesting mempty . MapNesting pat aux w $ zip (lambdaParams lam) arrs
    enter env =
      env
        { distNest = pushInnerNesting nest (distNest env),
          distScope = castScope (scopeOf lam) <> distScope env
        }

-- | Run an action inside a kernel nest.  The last loop nesting of the
-- nest becomes the innermost nesting, and the others (starting from the
-- outer one) become the enclosing nestings.  The scope of every loop
-- nesting is added to the environment's scope.
inNesting ::
  (Monad m, DistRep rep) =>
  KernelNest ->
  DistNestT rep m a ->
  DistNestT rep m a
inNesting (outer, nests) = local enter
  where
    enter env =
      env
        { distNest = (inner, nests'),
          distScope = foldMap scopeOfLoopNesting (outer : nests) <> distScope env
        }
    asNesting = Nesting mempty
    (inner, nests') =
      case reverse nests of
        [] -> (asNesting outer, [])
        innermost : rest ->
          (asNesting innermost, asNesting outer : map asNesting (reverse rest))

-- | Whether the body has a statement not tagged @sequential@ that is a
-- SOAC ('Op'), or a for-loop or 'WithAcc' whose body (recursively)
-- satisfies this check.  'BasicOp', 'Apply', 'Match' and while-loops
-- never count.
bodyContainsParallelism :: Body SOACS -> Bool
bodyContainsParallelism body = any isParallelStm (bodyStms body)
  where
    isParallelStm stm =
      isMap (stmExp stm)
        && not ("sequential" `inAttrs` stmAuxAttrs (stmAux stm))
    isMap = \case
      BasicOp {} -> False
      Apply {} -> False
      Match {} -> False
      DoLoop _ ForLoop {} loop_body -> bodyContainsParallelism loop_body
      DoLoop _ WhileLoop {} _ -> False
      WithAcc _ lam -> bodyContainsParallelism (lambdaBody lam)
      Op {} -> True

-- | Does the body of this lambda contain exploitable parallelism, in
-- the sense of 'bodyContainsParallelism'?
lambdaContainsParallelism :: Lambda SOACS -> Bool
lambdaContainsParallelism lam = bodyContainsParallelism (lambdaBody lam)

-- | Feed the statements of a map body into the distribution
-- accumulator, distributing each one where possible, and finally call
-- 'distribute' on the resulting accumulator.
--
-- Statements are handled from last to first: a statement is offered to
-- 'maybeDistributeStm' only after every statement following it has been
-- processed.
--
-- A 'Stream' SOAC is not offered directly.  Instead it is:
--
-- 1. sequentialised with 'sequentialStreamWholeArray';
-- 2. copy-propagated;
-- 3. given the certificates of the original statement.
--
-- The resulting statements are then processed in its place.
distributeMapBodyStms ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  Stms SOACS ->
  DistNestT rep m (DistAcc rep)
distributeMapBodyStms initial_acc body_stms = do
  acc <- processStms initial_acc (stmsToList body_stms)
  distribute acc
  where
    processStms acc [] = pure acc
    processStms acc (stm : rest) =
      case stm of
        Let pat (StmAux cs _ _) (Op (Stream w arrs accs lam)) -> do
          scope <- asksScope scopeForSOACs
          (_, seq_stms) <-
            runBuilderT (sequentialStreamWholeArray pat w accs lam arrs) scope
          seq_stms' <-
            runReaderT (copyPropagateInStms simpleSOACS scope seq_stms) scope
          processStms acc (map (certify cs) (stmsToList seq_stms') ++ rest)
        _ ->
          -- It is important that stm is in scope if 'maybeDistributeStm'
          -- wants to distribute, even if this causes the slightly silly
          -- situation that stm is in scope of itself.
          withStm stm $ do
            acc' <- processStms acc rest
            maybeDistributeStm stm acc'

-- | Handle an inner map by delegating to the 'distOnInnerMap' callback
-- stored in the distribution environment.
onInnerMap :: Monad m => MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
onInnerMap loop acc = asks distOnInnerMap >>= \handler -> handler loop acc

-- | Transform the given statements with the 'distOnTopLevelStms'
-- callback from the distribution environment, and record the result
-- with 'postStm'.
onTopLevelStms :: Monad m => Stms SOACS -> DistNestT rep m ()
onTopLevelStms stms = do
  transform <- asks distOnTopLevelStms
  transformed <- transform stms
  postStm transformed

maybeDistributeStm ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  Stm SOACS ->
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
maybeDistributeStm :: forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
maybeDistributeStm Stm SOACS
stm DistAcc rep
acc
  | Attr
"sequential" Attr -> Attrs -> Bool
`inAttrs` forall dec. StmAux dec -> Attrs
stmAuxAttrs (forall {k} (rep :: k). Stm rep -> StmAux (ExpDec rep)
stmAux Stm SOACS
stm) =
      forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
maybeDistributeStm (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op Op SOACS
soac)) DistAcc rep
acc
  | Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux (ExpDec SOACS)
aux =
      forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> Stms SOACS -> DistNestT rep m (DistAcc rep)
distributeMapBodyStms DistAcc rep
acc forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall {k} (rep :: k). Certs -> Stm rep -> Stm rep
certify (forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec SOACS)
aux))
        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall {k1} {k2} (m :: * -> *) (somerep :: k1) (rep :: k2) a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (forall (m :: * -> *).
Transformer m =>
Pat (LetDec (Rep m)) -> SOAC (Rep m) -> m ()
FOT.transformSOAC Pat (LetDec SOACS)
pat Op SOACS
soac)
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
_ (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form))) DistAcc rep
acc
  | Just Lambda SOACS
lam <- forall {k} (rep :: k). ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form =
      -- Only distribute inside the map if we can distribute everything
      -- following the map.
      forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> DistNestT rep m (Maybe (DistAcc rep))
distributeIfPossible DistAcc rep
acc forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
        Maybe (DistAcc rep)
Nothing -> forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
        Just DistAcc rep
acc' -> forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> DistNestT rep m (DistAcc rep)
distribute forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall {k} (m :: * -> *) (rep :: k).
Monad m =>
MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
onInnerMap (Pat Type
-> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> MapLoop
MapLoop Pat (LetDec SOACS)
pat (forall {k} (rep :: k). Stm rep -> StmAux (ExpDec rep)
stmAux Stm SOACS
stm) SubExp
w Lambda SOACS
lam [VName]
arrs) DistAcc rep
acc'
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (DoLoop [(FParam SOACS, SubExp)]
merge form :: LoopForm SOACS
form@ForLoop {} Body SOACS
body)) DistAcc rep
acc
  | forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Names -> Bool
`notNameIn` forall a. FreeIn a => a -> Names
freeIn Pat (LetDec SOACS)
pat) (forall dec. Pat dec -> [VName]
patNames Pat (LetDec SOACS)
pat),
    Body SOACS -> Bool
bodyContainsParallelism Body SOACS
body =
      forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
        Just (PostStms rep
kernels, Result
res, KernelNest
nest, DistAcc rep
acc')
          | -- XXX: We cannot distribute if this loop depends on
            -- certificates bound within the loop nest (well, we could,
            -- but interchange would not be valid).  This is not a
            -- fundamental restriction, but an artifact of our
            -- certificate representation, which we should probably
            -- rethink.
            Bool -> Bool
not forall a b. (a -> b) -> a -> b
$
              (forall a. FreeIn a => a -> Names
freeIn LoopForm SOACS
form forall a. Semigroup a => a -> a -> a
<> forall a. FreeIn a => a -> Names
freeIn StmAux (ExpDec SOACS)
aux)
                Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
            Just ([Int]
perm, [PatElem Type]
pat_unused) <- Pat Type -> Result -> Maybe ([Int], [PatElem Type])
permutationAndMissing Pat (LetDec SOACS)
pat Result
res ->
              -- We need to pretend pat_unused was used anyway, by adding
              -- it to the kernel nest.
              forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') forall a b. (a -> b) -> a -> b
$ do
                forall {k} (m :: * -> *) (rep :: k).
Monad m =>
PostStms rep -> DistNestT rep m ()
addPostStms PostStms rep
kernels
                KernelNest
nest' <- forall (m :: * -> *).
MonadFreshNames m =>
[PatElem Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElem Type]
pat_unused KernelNest
nest
                Scope SOACS
types <- forall {k} (rep :: k) (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope forall {k} (rep :: k).
SameScope rep SOACS =>
Scope rep -> Scope SOACS
scopeForSOACs

                -- Simplification is key to hoisting out statements that
                -- were variant to the loop, but invariant to the outer maps
                -- (which are now innermost).
                Stms SOACS
stms <-
                  (forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
`runReaderT` Scope SOACS
types) forall a b. (a -> b) -> a -> b
$
                    forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Stms SOACS -> m (Stms SOACS)
simplifyStms forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadFreshNames m, HasScope SOACS m) =>
KernelNest -> SeqLoop -> m (Stms SOACS)
interchangeLoops KernelNest
nest' ([Int]
-> Pat Type
-> [(FParam SOACS, SubExp)]
-> LoopForm SOACS
-> Body SOACS
-> SeqLoop
SeqLoop [Int]
perm Pat (LetDec SOACS)
pat [(FParam SOACS, SubExp)]
merge LoopForm SOACS
form Body SOACS
body)
                forall {k} (m :: * -> *) (rep :: k).
Monad m =>
Stms SOACS -> DistNestT rep m ()
onTopLevelStms Stms SOACS
stms
                forall (f :: * -> *) a. Applicative f => a -> f a
pure DistAcc rep
acc'
        Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ ->
          forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
_ (Match [SubExp]
cond [Case (Body SOACS)]
cases Body SOACS
defbody MatchDec (BranchType SOACS)
ret)) DistAcc rep
acc
  | forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Names -> Bool
`notNameIn` forall a. FreeIn a => a -> Names
freeIn Pat (LetDec SOACS)
pat) (forall dec. Pat dec -> [VName]
patNames Pat (LetDec SOACS)
pat),
    forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any Body SOACS -> Bool
bodyContainsParallelism (Body SOACS
defbody forall a. a -> [a] -> [a]
: forall a b. (a -> b) -> [a] -> [b]
map forall body. Case body -> body
caseBody [Case (Body SOACS)]
cases)
      Bool -> Bool -> Bool
|| Bool -> Bool
not (forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all forall shape u. TypeBase shape u -> Bool
primType (forall rt. MatchDec rt -> [rt]
matchReturns MatchDec (BranchType SOACS)
ret)) =
      forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
        Just (PostStms rep
kernels, Result
res, KernelNest
nest, DistAcc rep
acc')
          | Bool -> Bool
not forall a b. (a -> b) -> a -> b
$
              (forall a. FreeIn a => a -> Names
freeIn [SubExp]
cond forall a. Semigroup a => a -> a -> a
<> forall a. FreeIn a => a -> Names
freeIn MatchDec (BranchType SOACS)
ret) Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
            Just ([Int]
perm, [PatElem Type]
pat_unused) <- Pat Type -> Result -> Maybe ([Int], [PatElem Type])
permutationAndMissing Pat (LetDec SOACS)
pat Result
res ->
              -- We need to pretend pat_unused was used anyway, by adding
              -- it to the kernel nest.
              forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') forall a b. (a -> b) -> a -> b
$ do
                KernelNest
nest' <- forall (m :: * -> *).
MonadFreshNames m =>
[PatElem Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElem Type]
pat_unused KernelNest
nest
                forall {k} (m :: * -> *) (rep :: k).
Monad m =>
PostStms rep -> DistNestT rep m ()
addPostStms PostStms rep
kernels
                Scope SOACS
types <- forall {k} (rep :: k) (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope forall {k} (rep :: k).
SameScope rep SOACS =>
Scope rep -> Scope SOACS
scopeForSOACs
                let branch :: Branch
branch = [Int]
-> Pat Type
-> [SubExp]
-> [Case (Body SOACS)]
-> Body SOACS
-> MatchDec (BranchType SOACS)
-> Branch
Branch [Int]
perm Pat (LetDec SOACS)
pat [SubExp]
cond [Case (Body SOACS)]
cases Body SOACS
defbody MatchDec (BranchType SOACS)
ret
                Stms SOACS
stms <-
                  (forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
`runReaderT` Scope SOACS
types) forall a b. (a -> b) -> a -> b
$
                    forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Stms SOACS -> m (Stms SOACS)
simplifyStms forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). Stm rep -> Stms rep
oneStm forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadFreshNames m, HasScope SOACS m) =>
KernelNest -> Branch -> m (Stm SOACS)
interchangeBranch KernelNest
nest' Branch
branch
                forall {k} (m :: * -> *) (rep :: k).
Monad m =>
Stms SOACS -> DistNestT rep m ()
onTopLevelStms Stms SOACS
stms
                forall (f :: * -> *) a. Applicative f => a -> f a
pure DistAcc rep
acc'
        Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ ->
          forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
_ (WithAcc [WithAccInput SOACS]
inputs Lambda SOACS
lam)) DistAcc rep
acc
  | Lambda SOACS -> Bool
lambdaContainsParallelism Lambda SOACS
lam =
      forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
        Just (PostStms rep
kernels, Result
res, KernelNest
nest, DistAcc rep
acc')
          | Bool -> Bool
not forall a b. (a -> b) -> a -> b
$
              forall a. FreeIn a => a -> Names
freeIn (forall a. Int -> [a] -> [a]
drop Int
num_accs (forall {k} (rep :: k). Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
lam))
                Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
            Just ([Int]
perm, [PatElem Type]
pat_unused) <- Pat Type -> Result -> Maybe ([Int], [PatElem Type])
permutationAndMissing Pat (LetDec SOACS)
pat Result
res ->
              -- We need to pretend pat_unused was used anyway, by adding
              -- it to the kernel nest.
              forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') forall a b. (a -> b) -> a -> b
$ do
                KernelNest
nest' <- forall (m :: * -> *).
MonadFreshNames m =>
[PatElem Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElem Type]
pat_unused KernelNest
nest
                Scope SOACS
types <- forall {k} (rep :: k) (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope forall {k} (rep :: k).
SameScope rep SOACS =>
Scope rep -> Scope SOACS
scopeForSOACs
                forall {k} (m :: * -> *) (rep :: k).
Monad m =>
PostStms rep -> DistNestT rep m ()
addPostStms PostStms rep
kernels
                let withacc :: WithAccStm
withacc = [Int]
-> Pat Type -> [WithAccInput SOACS] -> Lambda SOACS -> WithAccStm
WithAccStm [Int]
perm Pat (LetDec SOACS)
pat [WithAccInput SOACS]
inputs Lambda SOACS
lam
                Stms SOACS
stms <-
                  (forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
`runReaderT` Scope SOACS
types) forall a b. (a -> b) -> a -> b
$
                    forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Stms SOACS -> m (Stms SOACS)
simplifyStms forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). Stm rep -> Stms rep
oneStm forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadFreshNames m, LocalScope SOACS m) =>
KernelNest -> WithAccStm -> m (Stm SOACS)
interchangeWithAcc KernelNest
nest' WithAccStm
withacc
                forall {k} (m :: * -> *) (rep :: k).
Monad m =>
Stms SOACS -> DistNestT rep m ()
onTopLevelStms Stms SOACS
stms
                forall (f :: * -> *) a. Applicative f => a -> f a
pure DistAcc rep
acc'
        Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ ->
          forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
  where
    num_accs :: Int
num_accs = forall (t :: * -> *) a. Foldable t => t a -> Int
length [WithAccInput SOACS]
inputs
maybeDistributeStm (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form))) DistAcc rep
acc
  | Just [Reduce Commutativity
comm Lambda SOACS
lam [SubExp]
nes] <- forall {k} (rep :: k). ScremaForm rep -> Maybe [Reduce rep]
isReduceSOAC ScremaForm SOACS
form,
    Just BuilderT SOACS (DistNestT rep m) ()
m <- forall (m :: * -> *).
(MonadBuilder m, Rep m ~ SOACS) =>
Pat Type
-> SubExp
-> Commutativity
-> Lambda SOACS
-> [(SubExp, VName)]
-> Maybe (m ())
irwim Pat (LetDec SOACS)
pat SubExp
w Commutativity
comm Lambda SOACS
lam forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
nes [VName]
arrs = do
      Scope SOACS
types <- forall {k} (rep :: k) (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope forall {k} (rep :: k).
SameScope rep SOACS =>
Scope rep -> Scope SOACS
scopeForSOACs
      (()
_, Stms SOACS
stms) <- forall {k} (m :: * -> *) (rep :: k) a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT (forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec SOACS)
aux BuilderT SOACS (DistNestT rep m) ()
m) Scope SOACS
types
      forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> Stms SOACS -> DistNestT rep m (DistAcc rep)
distributeMapBodyStms DistAcc rep
acc Stms SOACS
stms

-- Parallelise segmented scatters.
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat (StmAux Certs
cs Attrs
_ ExpDec SOACS
_) (Op (Scatter SubExp
w [VName]
ivs Lambda SOACS
lam [(Shape, Int, VName)]
as))) DistAcc rep
acc =
  forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostStms rep
kernels, Result
res, KernelNest
nest, DistAcc rep
acc')
      | Just ([Int]
perm, [PatElem Type]
pat_unused) <- Pat Type -> Result -> Maybe ([Int], [PatElem Type])
permutationAndMissing Pat (LetDec SOACS)
pat Result
res ->
          forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') forall a b. (a -> b) -> a -> b
$ do
            KernelNest
nest' <- forall (m :: * -> *).
MonadFreshNames m =>
[PatElem Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElem Type]
pat_unused KernelNest
nest
            Lambda rep
lam' <- forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Lambda SOACS -> DistNestT rep m (Lambda rep)
soacsLambda Lambda SOACS
lam
            forall {k} (m :: * -> *) (rep :: k).
Monad m =>
PostStms rep -> DistNestT rep m ()
addPostStms PostStms rep
kernels
            forall {k} (m :: * -> *) (rep :: k).
Monad m =>
Stms rep -> DistNestT rep m ()
postStm forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
KernelNest
-> [Int]
-> Pat Type
-> Certs
-> SubExp
-> Lambda rep
-> [VName]
-> [(Shape, Int, VName)]
-> DistNestT rep m (Stms rep)
segmentedScatterKernel KernelNest
nest' [Int]
perm Pat (LetDec SOACS)
pat Certs
cs SubExp
w Lambda rep
lam' [VName]
ivs [(Shape, Int, VName)]
as
            forall (f :: * -> *) a. Applicative f => a -> f a
pure DistAcc rep
acc'
    Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ ->
      forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
-- Parallelise segmented Hist.
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat (StmAux Certs
cs Attrs
_ ExpDec SOACS
_) (Op (Hist SubExp
w [VName]
as [HistOp SOACS]
ops Lambda SOACS
lam))) DistAcc rep
acc =
  forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostStms rep
kernels, Result
res, KernelNest
nest, DistAcc rep
acc')
      | Just ([Int]
perm, [PatElem Type]
pat_unused) <- Pat Type -> Result -> Maybe ([Int], [PatElem Type])
permutationAndMissing Pat (LetDec SOACS)
pat Result
res ->
          forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') forall a b. (a -> b) -> a -> b
$ do
            Lambda rep
lam' <- forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Lambda SOACS -> DistNestT rep m (Lambda rep)
soacsLambda Lambda SOACS
lam
            KernelNest
nest' <- forall (m :: * -> *).
MonadFreshNames m =>
[PatElem Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElem Type]
pat_unused KernelNest
nest
            forall {k} (m :: * -> *) (rep :: k).
Monad m =>
PostStms rep -> DistNestT rep m ()
addPostStms PostStms rep
kernels
            forall {k} (m :: * -> *) (rep :: k).
Monad m =>
Stms rep -> DistNestT rep m ()
postStm forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
KernelNest
-> [Int]
-> Certs
-> SubExp
-> [HistOp SOACS]
-> Lambda rep
-> [VName]
-> DistNestT rep m (Stms rep)
segmentedHistKernel KernelNest
nest' [Int]
perm Certs
cs SubExp
w [HistOp SOACS]
ops Lambda rep
lam' [VName]
as
            forall (f :: * -> *) a. Applicative f => a -> f a
pure DistAcc rep
acc'
    Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ ->
      forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
-- Parallelise Index slices if the result is going to be returned
-- directly from the kernel.  This is because we would otherwise have
-- to sequentialise writing the result, which may be costly.
maybeDistributeStm stm :: Stm SOACS
stm@(Let (Pat [PatElem (LetDec SOACS)
pe]) StmAux (ExpDec SOACS)
aux (BasicOp (Index VName
arr Slice SubExp
slice))) DistAcc rep
acc
  | Bool -> Bool
not forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> Bool
null forall a b. (a -> b) -> a -> b
$ forall d. Slice d -> [d]
sliceDims Slice SubExp
slice,
    VName -> SubExp
Var (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec SOACS)
pe) forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp (forall a b. (a, b) -> b
snd (Targets -> (Pat Type, Result)
innerTarget (forall {k} (rep :: k). DistAcc rep -> Targets
distTargets DistAcc rep
acc))) =
      forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
        Just (PostStms rep
kernels, Result
_res, KernelNest
nest, DistAcc rep
acc') ->
          forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') forall a b. (a -> b) -> a -> b
$ do
            forall {k} (m :: * -> *) (rep :: k).
Monad m =>
PostStms rep -> DistNestT rep m ()
addPostStms PostStms rep
kernels
            forall {k} (m :: * -> *) (rep :: k).
Monad m =>
Stms rep -> DistNestT rep m ()
postStm forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
KernelNest
-> Certs -> VName -> Slice SubExp -> DistNestT rep m (Stms rep)
segmentedGatherKernel KernelNest
nest (forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec SOACS)
aux) VName
arr Slice SubExp
slice
            forall (f :: * -> *) a. Applicative f => a -> f a
pure DistAcc rep
acc'
        Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ ->
          forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
-- If the scan can be distributed by itself, we will turn it into a
-- segmented scan.
--
-- If the scan cannot be distributed by itself, it will be
-- sequentialised in the default case for this function.
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat (StmAux Certs
cs Attrs
_ ExpDec SOACS
_) (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form))) DistAcc rep
acc
  | Just ([Scan SOACS]
scans, Lambda SOACS
map_lam) <- forall {k} (rep :: k).
ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC ScremaForm SOACS
form,
    Scan Lambda SOACS
lam [SubExp]
nes <- forall {k} (rep :: k). Buildable rep => [Scan rep] -> Scan rep
singleScan [Scan SOACS]
scans =
      forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
        Just (PostStms rep
kernels, Result
res, KernelNest
nest, DistAcc rep
acc')
          | Just ([Int]
perm, [PatElem Type]
pat_unused) <- Pat Type -> Result -> Maybe ([Int], [PatElem Type])
permutationAndMissing Pat (LetDec SOACS)
pat Result
res ->
              -- We need to pretend pat_unused was used anyway, by adding
              -- it to the kernel nest.
              forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') forall a b. (a -> b) -> a -> b
$ do
                KernelNest
nest' <- forall (m :: * -> *).
MonadFreshNames m =>
[PatElem Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElem Type]
pat_unused KernelNest
nest
                Lambda rep
map_lam' <- forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Lambda SOACS -> DistNestT rep m (Lambda rep)
soacsLambda Lambda SOACS
map_lam
                forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') forall a b. (a -> b) -> a -> b
$
                  forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
KernelNest
-> [Int]
-> Certs
-> SubExp
-> Lambda SOACS
-> Lambda rep
-> [SubExp]
-> [VName]
-> DistNestT rep m (Maybe (Stms rep))
segmentedScanomapKernel KernelNest
nest' [Int]
perm Certs
cs SubExp
w Lambda SOACS
lam Lambda rep
map_lam' [SubExp]
nes [VName]
arrs
                    forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Certs
-> Stm SOACS
-> DistAcc rep
-> PostStms rep
-> DistAcc rep
-> Maybe (Stms rep)
-> DistNestT rep m (DistAcc rep)
kernelOrNot forall a. Monoid a => a
mempty Stm SOACS
stm DistAcc rep
acc PostStms rep
kernels DistAcc rep
acc'
        Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ ->
          forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
-- If the map function of the reduction contains parallelism, we split
-- it so that the parallelism can be exploited.
maybeDistributeStm (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form))) DistAcc rep
acc
  | Just ([Reduce SOACS]
reds, Lambda SOACS
map_lam) <- forall {k} (rep :: k).
ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC ScremaForm SOACS
form,
    Lambda SOACS -> Bool
lambdaContainsParallelism Lambda SOACS
map_lam = do
      (Stm SOACS
mapstm, Stm SOACS
redstm) <-
        forall {k} (m :: * -> *) (rep :: k).
(MonadFreshNames m, Buildable rep, ExpDec rep ~ (),
 Op rep ~ SOAC rep) =>
Pat (LetDec rep)
-> (SubExp, [Reduce rep], Lambda rep, [VName])
-> m (Stm rep, Stm rep)
redomapToMapAndReduce Pat (LetDec SOACS)
pat (SubExp
w, [Reduce SOACS]
reds, Lambda SOACS
map_lam, [VName]
arrs)
      forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> Stms SOACS -> DistNestT rep m (DistAcc rep)
distributeMapBodyStms DistAcc rep
acc forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Stm rep -> Stms rep
oneStm Stm SOACS
mapstm {stmAux :: StmAux (ExpDec SOACS)
stmAux = StmAux (ExpDec SOACS)
aux} forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k). Stm rep -> Stms rep
oneStm Stm SOACS
redstm
-- If the reduction can be distributed by itself, we will turn it into a
-- segmented reduce.
--
-- If the reduction cannot be distributed by itself, it will be
-- sequentialised in the default case for this function.
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat (StmAux Certs
cs Attrs
_ ExpDec SOACS
_) (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form))) DistAcc rep
acc
  | Just ([Reduce SOACS]
reds, Lambda SOACS
map_lam) <- forall {k} (rep :: k).
ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC ScremaForm SOACS
form,
    Reduce Commutativity
comm Lambda SOACS
lam [SubExp]
nes <- forall {k} (rep :: k). Buildable rep => [Reduce rep] -> Reduce rep
singleReduce [Reduce SOACS]
reds =
      forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
        Just (PostStms rep
kernels, Result
res, KernelNest
nest, DistAcc rep
acc')
          | Just ([Int]
perm, [PatElem Type]
pat_unused) <- Pat Type -> Result -> Maybe ([Int], [PatElem Type])
permutationAndMissing Pat (LetDec SOACS)
pat Result
res ->
              -- We need to pretend pat_unused was used anyway, by adding
              -- it to the kernel nest.
              forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') forall a b. (a -> b) -> a -> b
$ do
                KernelNest
nest' <- forall (m :: * -> *).
MonadFreshNames m =>
[PatElem Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElem Type]
pat_unused KernelNest
nest

                Lambda rep
lam' <- forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Lambda SOACS -> DistNestT rep m (Lambda rep)
soacsLambda Lambda SOACS
lam
                Lambda rep
map_lam' <- forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Lambda SOACS -> DistNestT rep m (Lambda rep)
soacsLambda Lambda SOACS
map_lam

                let comm' :: Commutativity
comm'
                      | forall {k} (rep :: k). Lambda rep -> Bool
commutativeLambda Lambda SOACS
lam = Commutativity
Commutative
                      | Bool
otherwise = Commutativity
comm

                forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
KernelNest
-> [Int]
-> Certs
-> SubExp
-> Commutativity
-> Lambda rep
-> Lambda rep
-> [SubExp]
-> [VName]
-> DistNestT rep m (Maybe (Stms rep))
regularSegmentedRedomapKernel KernelNest
nest' [Int]
perm Certs
cs SubExp
w Commutativity
comm' Lambda rep
lam' Lambda rep
map_lam' [SubExp]
nes [VName]
arrs
                  forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Certs
-> Stm SOACS
-> DistAcc rep
-> PostStms rep
-> DistAcc rep
-> Maybe (Stms rep)
-> DistNestT rep m (DistAcc rep)
kernelOrNot forall a. Monoid a => a
mempty Stm SOACS
stm DistAcc rep
acc PostStms rep
kernels DistAcc rep
acc'
        Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ ->
          forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
maybeDistributeStm (Let Pat (LetDec SOACS)
pat (StmAux Certs
cs Attrs
_ ExpDec SOACS
_) (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form))) DistAcc rep
acc = do
  -- This Screma is too complicated for us to immediately do
  -- anything, so split it up and try again.
  Scope SOACS
scope <- forall {k} (rep :: k) (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope forall {k} (rep :: k).
SameScope rep SOACS =>
Scope rep -> Scope SOACS
scopeForSOACs
  forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> Stms SOACS -> DistNestT rep m (DistAcc rep)
distributeMapBodyStms DistAcc rep
acc forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall {k} (rep :: k). Certs -> Stm rep -> Stm rep
certify Certs
cs) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> b
snd
    forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall {k} (m :: * -> *) (rep :: k) a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT (forall (m :: * -> *).
(MonadBuilder m, Op (Rep m) ~ SOAC (Rep m), Buildable (Rep m)) =>
Pat (LetDec (Rep m))
-> SubExp -> ScremaForm (Rep m) -> [VName] -> m ()
dissectScrema Pat (LetDec SOACS)
pat SubExp
w ScremaForm SOACS
form [VName]
arrs) Scope SOACS
scope
maybeDistributeStm (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (BasicOp (Replicate (Shape (SubExp
d : [SubExp]
ds)) SubExp
v))) DistAcc rep
acc
  | [Type
t] <- forall dec. Typed dec => Pat dec -> [Type]
patTypes Pat (LetDec SOACS)
pat = do
      VName
tmp <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"tmp"
      let rowt :: Type
rowt = forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
t
          newstm :: Stm SOACS
newstm = forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k).
SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
d [] forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam
          tmpstm :: Stm SOACS
tmpstm =
            forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let (forall dec. [PatElem dec] -> Pat dec
Pat [forall dec. VName -> dec -> PatElem dec
PatElem VName
tmp Type
rowt]) StmAux (ExpDec SOACS)
aux forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape [SubExp]
ds) SubExp
v
          lam :: Lambda SOACS
lam =
            Lambda
              { lambdaReturnType :: [Type]
lambdaReturnType = [Type
rowt],
                lambdaParams :: [LParam SOACS]
lambdaParams = [],
                lambdaBody :: Body SOACS
lambdaBody = forall {k} (rep :: k).
Buildable rep =>
Stms rep -> Result -> Body rep
mkBody (forall {k} (rep :: k). Stm rep -> Stms rep
oneStm Stm SOACS
tmpstm) [VName -> SubExpRes
varRes VName
tmp]
              }
      forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
maybeDistributeStm Stm SOACS
newstm DistAcc rep
acc
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
_ StmAux (ExpDec SOACS)
aux (BasicOp (Copy VName
stm_arr))) DistAcc rep
acc =
  forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> VName
-> (KernelNest -> Pat Type -> VName -> DistNestT rep m (Stms rep))
-> DistNestT rep m (DistAcc rep)
distributeSingleUnaryStm DistAcc rep
acc Stm SOACS
stm VName
stm_arr forall a b. (a -> b) -> a -> b
$ \KernelNest
_ Pat Type
outerpat VName
arr ->
    forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Stm rep -> Stms rep
oneStm forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
outerpat StmAux (ExpDec SOACS)
aux forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
arr
-- Opaques are applied to the full array, because otherwise they can
-- drastically inhibit parallelisation in some cases.
maybeDistributeStm stm :: Stm SOACS
stm@(Let (Pat [PatElem (LetDec SOACS)
pe]) StmAux (ExpDec SOACS)
aux (BasicOp (Opaque OpaqueOp
_ (Var VName
stm_arr)))) DistAcc rep
acc
  | Bool -> Bool
not forall a b. (a -> b) -> a -> b
$ forall shape u. TypeBase shape u -> Bool
primType forall a b. (a -> b) -> a -> b
$ forall t. Typed t => t -> Type
typeOf PatElem (LetDec SOACS)
pe =
      forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> VName
-> (KernelNest -> Pat Type -> VName -> DistNestT rep m (Stms rep))
-> DistNestT rep m (DistAcc rep)
distributeSingleUnaryStm DistAcc rep
acc Stm SOACS
stm VName
stm_arr forall a b. (a -> b) -> a -> b
$ \KernelNest
_ Pat Type
outerpat VName
arr ->
        forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Stm rep -> Stms rep
oneStm forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
outerpat StmAux (ExpDec SOACS)
aux forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
arr
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
_ StmAux (ExpDec SOACS)
aux (BasicOp (Rearrange [Int]
perm VName
stm_arr))) DistAcc rep
acc =
  forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> VName
-> (KernelNest -> Pat Type -> VName -> DistNestT rep m (Stms rep))
-> DistNestT rep m (DistAcc rep)
distributeSingleUnaryStm DistAcc rep
acc Stm SOACS
stm VName
stm_arr forall a b. (a -> b) -> a -> b
$ \KernelNest
nest Pat Type
outerpat VName
arr -> do
    let r :: Int
r = forall (t :: * -> *) a. Foldable t => t a -> Int
length (forall a b. (a, b) -> b
snd KernelNest
nest) forall a. Num a => a -> a -> a
+ Int
1
        perm' :: [Int]
perm' = [Int
0 .. Int
r forall a. Num a => a -> a -> a
- Int
1] forall a. [a] -> [a] -> [a]
++ forall a b. (a -> b) -> [a] -> [b]
map (forall a. Num a => a -> a -> a
+ Int
r) [Int]
perm
    -- We need to add a copy, because the original map nest
    -- will have produced an array without aliases, and so must we.
    VName
arr' <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName forall a b. (a -> b) -> a -> b
$ VName -> [Char]
baseString VName
arr
    Type
arr_t <- forall {k} (rep :: k) (m :: * -> *).
HasScope rep m =>
VName -> m Type
lookupType VName
arr
    forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$
      forall {k} (rep :: k). [Stm rep] -> Stms rep
stmsFromList
        [ forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let (forall dec. [PatElem dec] -> Pat dec
Pat [forall dec. VName -> dec -> PatElem dec
PatElem VName
arr' Type
arr_t]) StmAux (ExpDec SOACS)
aux forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
arr,
          forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
outerpat StmAux (ExpDec SOACS)
aux forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp
Rearrange [Int]
perm' VName
arr'
        ]
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
_ StmAux (ExpDec SOACS)
aux (BasicOp (Reshape ReshapeKind
k Shape
reshape VName
stm_arr))) DistAcc rep
acc =
  forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> VName
-> (KernelNest -> Pat Type -> VName -> DistNestT rep m (Stms rep))
-> DistNestT rep m (DistAcc rep)
distributeSingleUnaryStm DistAcc rep
acc Stm SOACS
stm VName
stm_arr forall a b. (a -> b) -> a -> b
$ \KernelNest
nest Pat Type
outerpat VName
arr -> do
    let reshape' :: Shape
reshape' = forall d. [d] -> ShapeBase d
Shape (KernelNest -> [SubExp]
kernelNestWidths KernelNest
nest) forall a. Semigroup a => a -> a -> a
<> Shape
reshape
    forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Stm rep -> Stms rep
oneStm forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
outerpat StmAux (ExpDec SOACS)
aux forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ ReshapeKind -> Shape -> VName -> BasicOp
Reshape ReshapeKind
k Shape
reshape' VName
arr
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
_ StmAux (ExpDec SOACS)
aux (BasicOp (Rotate [SubExp]
rots VName
stm_arr))) DistAcc rep
acc =
  forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> VName
-> (KernelNest -> Pat Type -> VName -> DistNestT rep m (Stms rep))
-> DistNestT rep m (DistAcc rep)
distributeSingleUnaryStm DistAcc rep
acc Stm SOACS
stm VName
stm_arr forall a b. (a -> b) -> a -> b
$ \KernelNest
nest Pat Type
outerpat VName
arr -> do
    let rots' :: [SubExp]
rots' = forall a b. (a -> b) -> [a] -> [b]
map (forall a b. a -> b -> a
const forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (KernelNest -> [SubExp]
kernelNestWidths KernelNest
nest) forall a. [a] -> [a] -> [a]
++ [SubExp]
rots
    forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Stm rep -> Stms rep
oneStm forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
outerpat StmAux (ExpDec SOACS)
aux forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ [SubExp] -> VName -> BasicOp
Rotate [SubExp]
rots' VName
arr
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (BasicOp (Update Safety
_ VName
arr Slice SubExp
slice (Var VName
v)))) DistAcc rep
acc
  | Bool -> Bool
not forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> Bool
null forall a b. (a -> b) -> a -> b
$ forall d. Slice d -> [d]
sliceDims Slice SubExp
slice =
      forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
        Just (PostStms rep
kernels, Result
res, KernelNest
nest, DistAcc rep
acc')
          | forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
res forall a. Eq a => a -> a -> Bool
== forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var (forall dec. Pat dec -> [VName]
patNames forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Stm rep -> Pat (LetDec rep)
stmPat Stm SOACS
stm),
            Just ([Int]
perm, [PatElem Type]
pat_unused) <- Pat Type -> Result -> Maybe ([Int], [PatElem Type])
permutationAndMissing Pat (LetDec SOACS)
pat Result
res -> do
              forall {k} (m :: * -> *) (rep :: k).
Monad m =>
PostStms rep -> DistNestT rep m ()
addPostStms PostStms rep
kernels
              forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') forall a b. (a -> b) -> a -> b
$ do
                KernelNest
nest' <- forall (m :: * -> *).
MonadFreshNames m =>
[PatElem Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElem Type]
pat_unused KernelNest
nest
                forall {k} (m :: * -> *) (rep :: k).
Monad m =>
Stms rep -> DistNestT rep m ()
postStm
                  forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
KernelNest
-> [Int]
-> Certs
-> VName
-> Slice SubExp
-> VName
-> DistNestT rep m (Stms rep)
segmentedUpdateKernel KernelNest
nest' [Int]
perm (forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec SOACS)
aux) VName
arr Slice SubExp
slice VName
v
                forall (f :: * -> *) a. Applicative f => a -> f a
pure DistAcc rep
acc'
        Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ -> forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
_ StmAux (ExpDec SOACS)
aux (BasicOp (Concat Int
d (VName
x :| [VName]
xs) SubExp
w))) DistAcc rep
acc =
  forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostStms rep
kernels, Result
_, KernelNest
nest, DistAcc rep
acc') ->
      forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') forall a b. (a -> b) -> a -> b
$
        KernelNest -> DistNestT rep m (Maybe (Stms rep))
segmentedConcat KernelNest
nest
          forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Certs
-> Stm SOACS
-> DistAcc rep
-> PostStms rep
-> DistAcc rep
-> Maybe (Stms rep)
-> DistNestT rep m (DistAcc rep)
kernelOrNot (forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec SOACS)
aux) Stm SOACS
stm DistAcc rep
acc PostStms rep
kernels DistAcc rep
acc'
    Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ ->
      forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
  where
    segmentedConcat :: KernelNest -> DistNestT rep m (Maybe (Stms rep))
segmentedConcat KernelNest
nest =
      forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
KernelNest
-> [Int]
-> Names
-> Names
-> [SubExp]
-> [VName]
-> (Pat Type
    -> [(VName, SubExp)]
    -> [KernelInput]
    -> [SubExp]
    -> [VName]
    -> BuilderT rep m ())
-> DistNestT rep m (Maybe (Stms rep))
isSegmentedOp KernelNest
nest [Int
0] forall a. Monoid a => a
mempty forall a. Monoid a => a
mempty [] (VName
x forall a. a -> [a] -> [a]
: [VName]
xs) forall a b. (a -> b) -> a -> b
$
        \Pat Type
pat [(VName, SubExp)]
_ [KernelInput]
_ [SubExp]
_ (VName
x' : [VName]
xs') ->
          let d' :: Int
d' = Int
d forall a. Num a => a -> a -> a
+ forall (t :: * -> *) a. Foldable t => t a -> Int
length (forall a b. (a, b) -> b
snd KernelNest
nest) forall a. Num a => a -> a -> a
+ Int
1
           in forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
pat StmAux (ExpDec SOACS)
aux forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ Int -> NonEmpty VName -> SubExp -> BasicOp
Concat Int
d' (VName
x' forall a. a -> [a] -> NonEmpty a
:| [VName]
xs') SubExp
w
maybeDistributeStm Stm SOACS
stm DistAcc rep
acc =
  forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc

-- | Distribute a statement whose only array operand is @stm_arr@.
--
-- The statement is first run through 'distributeSingleStm'.  It is
-- then handled specially only when all of the following hold:
--
-- * the distributed result is exactly the statement's own pattern;
-- * the outermost nesting binds a single parameter/array pair;
-- * that parameter is the only nest-bound name free in the statement;
-- * the outer array is mapped, level by level, straight down to
--   @stm_arr@.
--
-- In that case the kernels from the accumulator are posted, and the
-- callback @f@ is given the nest, the outermost nesting's pattern and
-- the outer array.  The statements @f@ returns are posted in the scope
-- of the new accumulator.  In every other case the statement is simply
-- added to the accumulator.
distributeSingleUnaryStm ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  Stm SOACS ->
  VName ->
  (KernelNest -> Pat Type -> VName -> DistNestT rep m (Stms rep)) ->
  DistNestT rep m (DistAcc rep)
distributeSingleUnaryStm acc stm stm_arr f = do
  distributed <- distributeSingleStm acc stm
  case distributed of
    Just (kernels, res, nest@(outer, _), acc')
      | map resSubExp res == map Var (patNames (stmPat stm)),
        [(arr_p, arr)] <- loopNestingParamsAndArrs outer,
        usedFromNest nest == oneName (paramName arr_p),
        mappedThrough arr nest -> do
          addPostStms kernels
          localScope (typeEnvFromDistAcc acc') $ do
            new_stms <- f nest (loopNestingPat outer) arr
            postStm new_stms
            pure acc'
    _ -> addStmToAcc stm acc
  where
    -- The names bound by the nest that the statement refers to.
    usedFromNest :: KernelNest -> Names
    usedFromNest nest = boundInKernelNest nest `namesIntersection` freeIn stm

    -- Whether 'arr' is fed, one parameter per nesting level, all the
    -- way down to the statement's operand 'stm_arr'.
    mappedThrough :: VName -> KernelNest -> Bool
    mappedThrough arr (outer, inner) =
      case loopNestingParamsAndArrs outer of
        [(p, arr')]
          | arr == arr' ->
              case inner of
                [] -> paramName p == stm_arr
                next : rest -> mappedThrough (paramName p) (next, rest)
        _ -> False

-- | Flush the accumulator into kernels via 'distributeIfPossible'.
-- If distribution is not possible, the accumulator is returned as-is.
distribute ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
distribute acc =
  distributeIfPossible acc >>= \case
    Nothing -> pure acc
    Just acc' -> pure acc'

-- | Adapt the environment's 'distSegLevel' so it can be used from a
-- builder on top of 'DistNestT'.
--
-- The wrapped constructor is run as a separate builder in the inner
-- monad.  The statements it produces are added to the enclosing
-- builder, and the resulting level is returned.
mkSegLevel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistNestT rep m (MkSegLevel rep (DistNestT rep m))
mkSegLevel = asks $ \env ws desc recommendation -> do
  (lvl, lvl_stms) <-
    lift . liftInner . runBuilderT' $ distSegLevel env ws desc recommendation
  addStms lvl_stms
  pure lvl

-- | Try to distribute the accumulated statements over the current
-- nesting with 'tryDistribute'.
--
-- On success, the produced kernel statements are posted with 'postStm'.
-- A fresh accumulator is returned that holds the new targets and no
-- pending statements.  On failure, 'Nothing' is returned and no kernel
-- statements are posted.
distributeIfPossible ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  DistNestT rep m (Maybe (DistAcc rep))
distributeIfPossible acc = do
  nest <- asks distNest
  mk_lvl <- mkSegLevel
  attempt <- tryDistribute mk_lvl nest (distTargets acc) (distStms acc)
  traverse
    ( \(targets, kernel) -> do
        postStm kernel
        pure DistAcc {distTargets = targets, distStms = mempty}
    )
    attempt

-- | Distribute the accumulated statements, then try to distribute the
-- single statement @stm@ with the targets that step produced.
--
-- On success, the returned tuple contains:
--
-- * the statements produced from the accumulator, wrapped as
--   'PostStms' for the caller to post (they are not posted here);
-- * the 'Result' and 'KernelNest' reported by 'tryDistributeStm';
-- * a fresh accumulator holding the updated targets and no pending
--   statements.
--
-- If either distribution step fails, 'Nothing' is returned.
distributeSingleStm ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  Stm SOACS ->
  DistNestT
    rep
    m
    ( Maybe
        ( PostStms rep,
          Result,
          KernelNest,
          DistAcc rep
        )
    )
distributeSingleStm acc stm = do
  nest <- asks distNest
  mk_lvl <- mkSegLevel
  runMaybeT $ do
    (targets, distributed_stms) <-
      MaybeT $ tryDistribute mk_lvl nest (distTargets acc) (distStms acc)
    (res, targets', new_kernel_nest) <-
      MaybeT $ tryDistributeStm nest targets stm
    pure
      ( PostStms distributed_stms,
        res,
        new_kernel_nest,
        DistAcc {distTargets = targets', distStms = mempty}
      )

-- | Flatten a @scatter@ that appears inside the map nest @nest@ into a
-- single kernel (built with 'mapKernel') whose results are written in
-- place with 'WriteReturns'.
--
-- * @perm@ is applied with 'rearrangeShape' to the pattern elements of
--   the outermost nesting before the kernel is bound to them.
-- * @scatter_pat@, @cs@ and @scatter_w@ are the pattern, certificates and
--   width of the scatter itself; @ivs@ are its input arrays, paired with
--   the parameters of @lam@.
-- * @dests@ holds one @(shape, n, array)@ triple per destination.  The
--   first @sum [n * rank shape]@ results of @lam@ are indices and the
--   remaining results are the values to write.
--
-- Calls 'error' if a destination array does not correspond to a kernel
-- input, which can only happen if the original nested scatter was
-- ill-typed.
segmentedScatterKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Pat Type ->
  Certs ->
  SubExp ->
  Lambda rep ->
  [VName] ->
  [(Shape, Int, VName)] ->
  DistNestT rep m (Stms rep)
segmentedScatterKernel nest perm scatter_pat cs scatter_w lam ivs dests = do
  -- We replicate some of the checking done by 'isSegmentedOp', but
  -- things are different because a scatter is not a reduction or
  -- scan.
  --
  -- First, pretend that the scatter is also part of the nesting.  The
  -- KernelNest we produce here is technically not sensible, but it's
  -- good enough for flatKernel to work.
  let nesting =
        MapNesting scatter_pat (StmAux cs mempty ()) scatter_w $
          zip (lambdaParams lam) ivs
      nest' =
        pushInnerKernelNesting (scatter_pat, bodyResult $ lambdaBody lam) nesting nest
  (ispace, kernel_inps) <- flatKernel nest'

  let (as_ws, as_ns, dest_arrs) = unzip3 dests
      -- Number of index results per destination: one index per
      -- dimension of the destination shape, for each of its writes.
      indexes = zipWith (*) as_ns $ map length as_ws
      num_indexes = sum indexes

  -- The destination arrays _must_ correspond to some kernel input, or
  -- else the original nested scatter would have been ill-typed.  Find
  -- them.
  as_inps <- mapM (findInput kernel_inps) dest_arrs

  mk_lvl <- mkSegLevel

  -- One result type per destination: the type of its first value.
  let rts =
        concatMap (take 1) $
          chunks as_ns $
            drop num_indexes $
              lambdaReturnType lam
      (idx_res, val_res) = splitAt num_indexes $ bodyResult $ lambdaBody lam

  k_body_stms <- runBuilder_ $ addStms $ bodyStms $ lambdaBody lam

  let k_body =
        groupScatterResults (zip3 as_ws as_ns as_inps) (idx_res ++ val_res)
          & map (inPlaceReturn ispace)
          & KernelBody () k_body_stms
      -- Remove unused kernel inputs, since some of these might
      -- reference the array we are scattering into.
      kernel_inps' =
        filter ((`nameIn` freeIn k_body) . kernelInputName) kernel_inps

  (k, k_stms) <- mapKernel mk_lvl ispace kernel_inps' rts k_body

  traverse renameStm <=< runBuilder_ $ do
    addStms k_stms

    let pat =
          Pat . rearrangeShape perm $
            patElems $
              loopNestingPat $
                fst nest

    letBind pat $ Op $ segOp k
  where
    -- Look up the kernel input bound to the given array name.
    findInput :: Applicative f => [KernelInput] -> VName -> f KernelInput
    findInput inps a =
      maybe (error "Ill-typed nested scatter encountered.") pure $
        find ((== a) . kernelInputName) inps

    -- The in-place write for one destination.  Each write is indexed by
    -- all kernel thread indices except the last (presumably the one
    -- introduced by the scatter nesting pushed above), followed by the
    -- indices computed by the scatter lambda.  Certificates of both the
    -- indices and the values are attached to the result.
    inPlaceReturn ::
      [(VName, SubExp)] ->
      (Shape, KernelInput, [(Result, SubExpRes)]) ->
      KernelResult
    inPlaceReturn kspace (aw, inp, is_vs) =
      WriteReturns
        ( foldMap (foldMap resCerts . fst) is_vs
            <> foldMap (resCerts . snd) is_vs
        )
        (Shape (init ws ++ shapeDims aw))
        (kernelInputArray inp)
        [ (Slice $ map DimFix (map Var (init gtids) ++ map resSubExp is), resSubExp v)
          | (is, v) <- is_vs
        ]
      where
        (gtids, ws) = unzip kspace

-- | Flatten an in-place update of @arr@ at @slice@ with the value @v@,
-- occurring inside the map nest @nest@, into a single kernel.
--
-- The kernel's index space is that of the flattened nest extended with
-- one dimension per entry of 'sliceDims' of @slice@.  Each thread reads
-- one element of @v@ (under the certificates @cs@) and writes it with
-- 'WriteReturns' into the kernel input array corresponding to @arr@.
-- The kernel is bound to the outermost nesting pattern, with its
-- elements rearranged by @perm@.
--
-- Calls 'error' if @arr@ does not correspond to a kernel input.
segmentedUpdateKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Certs ->
  VName ->
  Slice SubExp ->
  VName ->
  DistNestT rep m (Stms rep)
segmentedUpdateKernel nest perm cs arr slice v = do
  (base_ispace, kernel_inps) <- flatKernel nest
  let slice_dims = sliceDims slice
  slice_gtids <- replicateM (length slice_dims) (newVName "gtid_slice")

  let ispace = base_ispace ++ zip slice_gtids slice_dims

  ((res_t, res), kstms) <- runBuilder $ do
    -- Read the element of 'v' that this thread writes.
    v' <-
      certifying cs $
        letSubExp "v" $
          BasicOp $
            Index v $
              Slice $
                map (DimFix . Var) slice_gtids

    -- Compute indexes into the full array.
    slice_is <-
      traverse (toSubExp "index") $
        fixSlice (fmap pe64 slice) $
          map (pe64 . Var) slice_gtids

    let write_is = map (Var . fst) base_ispace ++ slice_is
        arr' =
          maybe (error "incorrectly typed Update") kernelInputArray $
            find ((== arr) . kernelInputName) kernel_inps
    arr_t <- lookupType arr'
    v_t <- subExpType v'
    pure
      ( v_t,
        WriteReturns mempty (arrayShape arr_t) arr' [(Slice $ map DimFix write_is, v')]
      )

  -- Remove unused kernel inputs, since some of these might
  -- reference the array we are updating.
  let kernel_inps' =
        filter ((`nameIn` (freeIn kstms <> freeIn res)) . kernelInputName) kernel_inps

  mk_lvl <- mkSegLevel
  (k, prestms) <-
    mapKernel mk_lvl ispace kernel_inps' [res_t] $
      KernelBody () kstms [res]

  traverse renameStm <=< runBuilder_ $ do
    addStms prestms

    let pat =
          Pat . rearrangeShape perm $
            patElems $
              loopNestingPat $
                fst nest

    letBind pat $ Op $ segOp k

-- | Flatten an index @arr[slice]@ occurring inside the map nest @nest@
-- into a single kernel.
--
-- The kernel's index space is that of the flattened nest extended with
-- one dimension per entry of 'sliceDims' of @slice@.  Each thread indexes
-- @arr@ (under the certificates @cs@) with @slice@ composed with its own
-- slice indices, and returns the result with 'Returns'.  The kernel is
-- bound to the pattern of the outermost nesting, unpermuted.
segmentedGatherKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  Certs ->
  VName ->
  Slice SubExp ->
  DistNestT rep m (Stms rep)
segmentedGatherKernel nest cs arr slice = do
  let slice_dims = sliceDims slice
  slice_gtids <- replicateM (length slice_dims) (newVName "gtid_slice")

  (base_ispace, kernel_inps) <- flatKernel nest
  let ispace = base_ispace ++ zip slice_gtids slice_dims

  ((res_t, res), kstms) <- runBuilder $ do
    -- Compute indexes into full array.
    slice'' <-
      subExpSlice . sliceSlice (primExpSlice slice) $
        primExpSlice $
          Slice $
            map (DimFix . Var) slice_gtids
    v' <- certifying cs $ letSubExp "v" $ BasicOp $ Index arr slice''
    v_t <- subExpType v'
    pure (v_t, Returns ResultMaySimplify mempty v')

  mk_lvl <- mkSegLevel
  (k, prestms) <-
    mapKernel mk_lvl ispace kernel_inps [res_t] $
      KernelBody () kstms [res]

  traverse renameStm <=< runBuilder_ $ do
    addStms prestms

    -- 'Pat . patElems' is the identity, so bind the nesting pattern directly.
    let pat = loopNestingPat $ fst nest

    letBind pat $ Op $ segOp k

-- | Flatten a histogram nested inside the map nest @nest@ into a
-- segmented histogram built by 'histKernel'.
--
-- Each destination array of @ops@ is replaced by the array of the
-- corresponding kernel input.  The result is bound to the outermost
-- nesting pattern, with its elements rearranged by @perm@.  The segment
-- level comes from 'distSegLevel' (with @NoRecommendation SegNoVirt@).
-- Operator lambdas are converted with 'distOnSOACSLambda'.
--
-- Calls 'error' if a destination array does not correspond to a kernel
-- input.
segmentedHistKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Certs ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda rep ->
  [VName] ->
  DistNestT rep m (Stms rep)
segmentedHistKernel nest perm cs hist_w ops lam arrs = do
  -- We replicate some of the checking done by 'isSegmentedOp', but
  -- things are different because a Hist is not a reduction or
  -- scan.
  (ispace, inputs) <- flatKernel nest
  let orig_pat =
        Pat . rearrangeShape perm $
          patElems $
            loopNestingPat $
              fst nest

  -- The input/output arrays _must_ correspond to some kernel input,
  -- or else the original nested Hist would have been ill-typed.
  -- Find them.
  ops' <- forM ops $ \(SOACS.HistOp num_bins rf dests nes op) -> do
    dests' <- mapM (fmap kernelInputArray . findInput inputs) dests
    pure $ SOACS.HistOp num_bins rf dests' nes op

  mk_lvl <- asks distSegLevel
  onLambda <- asks distOnSOACSLambda
  let onLambda' = fmap fst . runBuilder . onLambda
  liftInner $
    runBuilderT'_ $ do
      -- It is important not to launch unnecessarily many threads for
      -- histograms, because it may mean we unnecessarily need to reduce
      -- subhistograms as well.
      lvl <- mk_lvl (hist_w : map snd ispace) "seghist" $ NoRecommendation SegNoVirt
      addStms
        =<< histKernel onLambda' lvl orig_pat ispace inputs cs hist_w ops' lam arrs
  where
    -- Look up the kernel input bound to the given array name.
    findInput :: Applicative f => [KernelInput] -> VName -> f KernelInput
    findInput kernel_inps a =
      maybe (error "Ill-typed nested Hist encountered.") pure $
        find ((== a) . kernelInputName) kernel_inps

-- | Build a segmented histogram ('segHist') from SOACS histogram operators.
--
-- For each operator, 'determineReduceOp' may peel nested @map@s off the
-- operator lambda, which is then converted with @onLambda@.  Kernel
-- inputs whose arrays are histogram destinations are dropped from
-- @inputs@.  The generated statements are renamed, added under the
-- certificates @cs@, and returned.
histKernel ::
  (MonadBuilder m, DistRep (Rep m)) =>
  (Lambda SOACS -> m (Lambda (Rep m))) ->
  SegOpLevel (Rep m) ->
  Pat Type ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  Certs ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda (Rep m) ->
  [VName] ->
  m (Stms (Rep m))
histKernel onLambda lvl orig_pat ispace inputs cs hist_w ops lam arrs = runBuilderT'_ $ do
  ops' <- forM ops $ \(SOACS.HistOp dest_shape rf dests nes op) -> do
    (op', nes', shape) <- determineReduceOp op nes
    op'' <- lift $ onLambda op'
    pure $ HistOp dest_shape rf dests nes' shape op''

  -- Drop kernel inputs that refer to histogram destination arrays.
  let dest_arrs = concatMap histDest ops'
      inputs' = filter ((`notElem` dest_arrs) . kernelInputArray) inputs

  certifying cs $
    addStms
      =<< traverse renameStm
      =<< segHist lvl orig_pat hist_w ispace inputs' ops' lam arrs

-- | Prepare a histogram operator for segmented histogram generation.
--
-- If every neutral element is a variable and the operator is a
-- non-trivial perfect nest of @map@s over its parameters (see
-- 'isVectorMap'), this returns three things:
--
-- * the innermost lambda;
-- * new neutral elements, obtained by indexing each neutral array at
--   index 0 in each of the mapped dimensions;
-- * the shape of the mapped dimensions.
--
-- Otherwise the operator and neutral elements are returned unchanged,
-- with an empty shape.  No statements are emitted in that case.
determineReduceOp ::
  MonadBuilder m =>
  Lambda SOACS ->
  [SubExp] ->
  m (Lambda SOACS, [SubExp], Shape)
determineReduceOp lam nes
  -- FIXME? We are assuming that the accumulator is a replicate, and
  -- we fish out its value in a gross way.
  | Just ne_vs <- mapM subExpVar nes,
    (shape, lam') <- isVectorMap lam,
    -- Without any mapped dimensions there is nothing to fish out, and
    -- indexing with an empty prefix would only re-bind each neutral
    -- element under a new name.
    shapeRank shape > 0 = do
      nes' <- forM ne_vs $ \ne_v -> do
        ne_v_t <- lookupType ne_v
        letSubExp "hist_ne" $
          BasicOp $
            Index ne_v $
              fullSlice ne_v_t $
                replicate (shapeRank shape) $
                  DimFix $
                    intConst Int64 0
      pure (lam', nes', shape)
  | otherwise =
      pure (lam, nes, mempty)

-- | Peel perfectly nested maps off a lambda.
--
-- A lambda counts as a vector map when its body consists of a single
-- 'Screma' statement that is a map, whose results are exactly the body
-- result, and whose array inputs are exactly the lambda parameters.
-- In that case the map width is recorded and the inner map lambda is
-- examined recursively. The returned shape lists the widths from the
-- outermost map inwards, and the lambda is the innermost one. For any
-- other lambda, the result is an empty shape and the lambda itself.
isVectorMap :: Lambda SOACS -> (Shape, Lambda SOACS)
isVectorMap lam =
  case stmsToList body_stms of
    [Let (Pat pes) _ (Op (Screma w arrs form))]
      | map resSubExp body_res == map (Var . patElemName) pes,
        Just inner_lam <- isMapSOAC form,
        arrs == map paramName (lambdaParams lam) ->
          let (inner_shape, innermost) = isVectorMap inner_lam
           in (Shape [w] <> inner_shape, innermost)
    _ -> (mempty, lam)
  where
    body = lambdaBody lam
    body_stms = bodyStms body
    body_res = bodyResult body

-- | Try to express a scanomap inside a kernel nest as one segmented
-- scan via 'segScan'.
--
-- The scan operator is normalised with 'determineReduceOp', which may
-- turn a vectorised operator into an inner operator plus a shape. It is
-- then converted with the environment's 'distOnSOACSLambda'. The scan
-- is treated as non-commutative. The result is 'Nothing' when
-- 'isSegmentedOp' rejects the nest.
segmentedScanomapKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Certs ->
  SubExp ->
  Lambda SOACS ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  DistNestT rep m (Maybe (Stms rep))
segmentedScanomapKernel nest perm cs segment_size lam map_lam nes arrs = do
  mkLevel <- asks distSegLevel
  toRepLambda <- asks distOnSOACSLambda
  -- Only the converted lambda is kept; any statements produced by the
  -- conversion builder are discarded.
  let convertLambda = fmap fst . runBuilder . toRepLambda
  isSegmentedOp nest perm (freeIn lam) (freeIn map_lam) nes [] $
    \pat ispace inps seg_nes _ -> do
      (op_lam, op_nes, op_shape) <- determineReduceOp lam seg_nes
      op_lam' <- convertLambda op_lam
      lvl <-
        mkLevel (segment_size : map snd ispace) "segscan" $
          NoRecommendation SegNoVirt
      stms <-
        segScan
          lvl
          pat
          cs
          segment_size
          [SegBinOp Noncommutative op_lam' op_nes op_shape]
          map_lam
          arrs
          ispace
          inps
      traverse renameStm stms >>= addStms

-- | Try to express a redomap inside a kernel nest as one segmented
-- reduction via 'segRed'.
--
-- The reduction operator is used as-is with the given commutativity
-- and an empty operator shape. The result is 'Nothing' when
-- 'isSegmentedOp' rejects the nest.
regularSegmentedRedomapKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Certs ->
  SubExp ->
  Commutativity ->
  Lambda rep ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  DistNestT rep m (Maybe (Stms rep))
regularSegmentedRedomapKernel nest perm cs segment_size comm lam map_lam nes arrs = do
  mkLevel <- asks distSegLevel
  isSegmentedOp nest perm (freeIn lam) (freeIn map_lam) nes [] $
    \pat ispace inps seg_nes _ -> do
      let seg_dims = segment_size : map snd ispace
      lvl <- mkLevel seg_dims "segred" $ NoRecommendation SegNoVirt
      stms <-
        segRed
          lvl
          pat
          cs
          segment_size
          [SegBinOp comm lam seg_nes mempty]
          map_lam
          arrs
          ispace
          inps
      traverse renameStm stms >>= addStms

-- | Check whether an operation nested inside a kernel nest can become a
-- single segmented operation. If it can, run the continuation @m@ in a
-- builder and return the statements it produces.
--
-- The continuation receives:
--
--   * the pattern of the outermost nesting, with its elements rearranged
--     by @perm@;
--   * the flattened segment space;
--   * the kernel inputs;
--   * the neutral elements;
--   * the prepared array inputs.
--
-- 'Nothing' is returned when any of the checks below fails. The
-- @_free_in_fold_op@ argument is not consulted.
isSegmentedOp ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Names ->
  Names ->
  [SubExp] ->
  [VName] ->
  ( Pat Type ->
    [(VName, SubExp)] ->
    [KernelInput] ->
    [SubExp] ->
    [VName] ->
    BuilderT rep m ()
  ) ->
  DistNestT rep m (Maybe (Stms rep))
isSegmentedOp nest perm free_in_op _free_in_fold_op nes arrs m = runMaybeT $ do
  -- The operation is accepted only if all of the following hold:
  --
  --  * nothing free in the operator (free_in_op) is bound by the nest;
  --    such ops are rejected rather than rewritten;
  --  * no neutral element is a variable bound by the nest;
  --  * every array input is either a kernel input indexed by the full
  --    segment space, or free in the nest (in which case it is
  --    replicated across the segment space).
  let bound_by_nest = boundInKernelNest nest

  (ispace, kernel_inps) <- flatKernel nest

  when (free_in_op `namesIntersect` bound_by_nest) $
    fail "Non-fold lambda uses nest-bound parameters."

  let indices = map fst ispace

      prepareNe (Var v)
        | v `nameIn` bound_by_nest =
            fail "Neutral element bound in nest"
      prepareNe ne = pure ne

      prepareArr arr =
        case find ((== arr) . kernelInputName) kernel_inps of
          Just inp
            | kernelInputIndices inp == map Var indices ->
                -- Perfectly mapped input: use the underlying array.
                pure $ pure $ kernelInputArray inp
          Nothing
            | arr `notNameIn` bound_by_nest ->
                -- This input is something that is free inside
                -- the loop nesting. We will have to replicate
                -- it.
                pure $
                  letExp
                    (baseString arr ++ "_repd")
                    (BasicOp $ Replicate (Shape $ map snd ispace) $ Var arr)
          _ ->
            fail "Input not free, perfectly mapped, or outermost."

  nes' <- mapM prepareNe nes

  mk_arrs <- mapM prepareArr arrs

  lift $
    liftInner $
      runBuilderT'_ $ do
        nested_arrs <- sequence mk_arrs

        let pat =
              Pat . rearrangeShape perm $
                patElems $
                  loopNestingPat $
                    fst nest

        m pat ispace kernel_inps nes' nested_arrs

-- | Relate a pattern to a body result.
--
-- Pattern elements whose names do not occur free in the result are the
-- "missing" ones. They are appended, as variables, after the result
-- subexpressions. The function returns the permutation (per
-- 'isPermutationOf') between the pattern variables and this expanded
-- result, together with the missing elements. It returns 'Nothing' when
-- the pattern variables are not a permutation of the expanded result.
permutationAndMissing :: Pat Type -> Result -> Maybe ([Int], [PatElem Type])
permutationAndMissing (Pat pes) res =
  withMissing <$> (pat_vars `isPermutationOf` expanded)
  where
    res_names = freeIn res
    missing = filter (not . (`nameIn` res_names) . patElemName) pes
    pat_vars = map (Var . patElemName) pes
    expanded = map resSubExp res ++ map (Var . patElemName) missing
    withMissing perm = (perm, missing)

-- | Add extra pattern elements to every kernel nesting level.
--
-- Each given pattern element is copied, under a fresh name, into the
-- pattern of every nesting. Its type is lifted to an array over the
-- widths of that nesting and of all nestings inside it. The outermost
-- nesting is processed first, then the inner ones from the outside in.
expandKernelNest ::
  MonadFreshNames m => [PatElem Type] -> KernelNest -> m KernelNest
expandKernelNest pes (outer_nest, inner_nests) =
  (,)
    <$> expandWith outer_nest (loopNestingWidth outer_nest : inner_widths)
    <*> zipWithM expandWith inner_nests (tails inner_widths)
  where
    inner_widths = map loopNestingWidth inner_nests

    -- Append fresh, array-lifted copies of 'pes' to a nesting's pattern.
    expandWith nest dims = do
      new_pes <- traverse (expandPatElemWith dims) pes
      let old_pes = patElems $ loopNestingPat nest
      pure nest {loopNestingPat = Pat $ old_pes <> new_pes}

    -- Rename a pattern element and wrap its type in the given dimensions.
    expandPatElemWith dims pe = do
      fresh <- newVName $ baseString $ patElemName pe
      pure
        pe
          { patElemName = fresh,
            patElemDec = patElemType pe `arrayOfShape` Shape dims
          }

-- | Commit to a kernel, or fall back to accumulating the statement.
--
-- Without kernel statements, the certified original statement is added
-- to the first accumulator via 'addStmToAcc'. With kernel statements,
-- the given post-statements are registered, the certified kernel
-- statements are emitted as post-statements, and the second
-- accumulator is returned.
kernelOrNot ::
  (MonadFreshNames m, DistRep rep) =>
  Certs ->
  Stm SOACS ->
  DistAcc rep ->
  PostStms rep ->
  DistAcc rep ->
  Maybe (Stms rep) ->
  DistNestT rep m (DistAcc rep)
kernelOrNot cs stm acc kernels acc' maybe_stms =
  case maybe_stms of
    Nothing ->
      addStmToAcc (certify cs stm) acc
    Just stms -> do
      addPostStms kernels
      postStm $ certify cs <$> stms
      pure acc'

-- | Distribute a map loop.
--
-- The map's body statements are distributed using a fresh accumulator.
-- That accumulator has no statements, and its targets are the current
-- targets with the map's pattern and body result pushed as the
-- innermost target. This runs under 'mapNesting', and the accumulator
-- it yields is distributed once more.
distributeMap ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  MapLoop ->
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
distributeMap (MapLoop pat aux w lam arrs) acc = do
  nested_acc <-
    mapNesting pat aux w lam arrs $
      distributeMapBodyStms inner_acc body_stms >>= distribute
  distribute nested_acc
  where
    body = lambdaBody lam
    body_stms = bodyStms body
    inner_acc =
      DistAcc
        { distTargets =
            pushInnerTarget (pat, bodyResult body) (distTargets acc),
          distStms = mempty
        }