{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Pass.ExtractKernels.DistributeNests
  ( MapLoop (..),
    mapLoopStm,
    bodyContainsParallelism,
    lambdaContainsParallelism,
    determineReduceOp,
    histKernel,
    DistEnv (..),
    DistAcc (..),
    runDistNestT,
    DistNestT,
    liftInner,
    distributeMap,
    distribute,
    distributeSingleStm,
    distributeMapBodyStms,
    addStmsToAcc,
    addStmToAcc,
    permutationAndMissing,
    addPostStms,
    postStm,
    inNesting,
  )
where

import Control.Arrow (first)
import Control.Monad.Identity
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.Trans.Maybe
import Control.Monad.Writer.Strict
import Data.List (find, partition, tails)
import qualified Data.Map as M
import Data.Maybe
import Futhark.IR
import Futhark.IR.SOACS (SOACS)
import qualified Futhark.IR.SOACS as SOACS
import Futhark.IR.SOACS.SOAC hiding (HistOp, histDest)
import Futhark.IR.SOACS.Simplify (simpleSOACS, simplifyStms)
import Futhark.IR.SegOp
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.Interchange
import Futhark.Tools
import Futhark.Transform.CopyPropagate
import qualified Futhark.Transform.FirstOrderTransform as FOT
import Futhark.Transform.Rename
import Futhark.Util
import Futhark.Util.Log

-- | View a scope of some lore as a 'SOACS' scope.  This is just
-- 'castScope', which the 'SameScope' constraint permits.
scopeForSOACs :: SameScope lore SOACS => Scope lore -> Scope SOACS
scopeForSOACs = castScope

-- | A @map@ SOAC taken apart into its components, in the order that
-- 'mapLoopStm' reassembles them: the bound pattern, the statement's
-- auxiliary information, the width, the mapped lambda, and the input
-- arrays.
data MapLoop
  = MapLoop
      SOACS.Pattern
      (StmAux ())
      SubExp
      SOACS.Lambda
      [VName]

-- | Turn a 'MapLoop' back into the SOACS statement it describes: a
-- 'Screma' of width @w@ over @arrs@ whose form is a plain map with
-- lambda @lam@, bound to @pat@ with auxiliary information @aux@.
mapLoopStm :: MapLoop -> Stm SOACS
mapLoopStm (MapLoop pat aux w lam arrs) =
  Let pat aux $ Op $ Screma w (mapSOAC lam) arrs

-- | The read-only environment of 'DistNestT'.
data DistEnv lore m = DistEnv
  { -- | The loop nesting context we are currently distributing inside;
    -- extended by 'withStm' and consulted by 'runDistNestT'.
    distNest :: Nestings,
    -- | The type scope, extended by 'localScope' and 'withStm'.
    distScope :: Scope lore,
    -- | Callback for handling top-level SOACS statements (not used in
    -- this part of the module; presumably invoked by the distribution
    -- driver - verify against callers).
    distOnTopLevelStms :: Stms SOACS -> DistNestT lore m (Stms lore),
    -- | Callback for handling an inner map encountered during
    -- distribution (not used in this part of the module - verify
    -- against callers).
    distOnInnerMap ::
      MapLoop ->
      DistAcc lore ->
      DistNestT lore m (DistAcc lore),
    -- | Converts a single SOACS statement into target-lore statements;
    -- used by 'addStmToAcc'.
    distOnSOACSStms :: Stm SOACS -> Binder lore (Stms lore),
    -- | Converts a SOACS lambda into a target-lore lambda; used by
    -- 'soacsLambda'.
    distOnSOACSLambda :: Lambda SOACS -> Binder lore (Lambda lore),
    -- | How to choose segment levels for generated 'SegOp's
    -- (semantics of 'MkSegLevel' defined elsewhere).
    distSegLevel :: MkSegLevel lore m
  }

-- | The accumulator threaded through distribution.
data DistAcc lore = DistAcc
  { -- | The result targets of the nesting being distributed.
    distTargets :: Targets,
    -- | Statements accumulated so far; new statements are prepended
    -- by 'addStmsToAcc' and 'addStmToAcc'.
    distStms :: Stms lore
  }

-- | The writer output of 'DistNestT'.
data DistRes lore = DistRes
  { -- | Statements recorded with 'addPostStms' / 'postStm'; they form
    -- the head of the statements returned by 'runDistNestT'.
    accPostStms :: PostStms lore,
    -- | Log messages, forwarded to the underlying monad by
    -- 'runDistNestT'.
    accLog :: Log
  }

-- | Combine both components with their own 'Semigroup' instances.
instance Semigroup (DistRes lore) where
  DistRes ks1 log1 <> DistRes ks2 log2 =
    DistRes (ks1 <> ks2) (log1 <> log2)

-- | No post-statements and an empty log.
instance Monoid (DistRes lore) where
  mempty = DistRes mempty mempty

-- | Statements produced while distributing.  This is a newtype so that
-- its 'Semigroup' instance can combine them in reverse order relative
-- to plain 'Stms'.
newtype PostStms lore = PostStms {unPostStms :: Stms lore}

-- | Deliberately reversed: in @a <> b@ the statements of @b@ come
-- first.  Because the writer appends each 'tell' to the right, the
-- statements recorded later end up earlier in the final sequence.
-- (Reversing a monoid keeps it associative.)
instance Semigroup (PostStms lore) where
  PostStms xs <> PostStms ys = PostStms $ ys <> xs

-- | No statements.
instance Monoid (PostStms lore) where
  mempty = PostStms mempty

-- | The scope bound by the pattern of the outermost target of the
-- accumulator.
typeEnvFromDistAcc :: DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc = scopeOfPattern . fst . outerTarget . distTargets

-- | Prepend the given statements to those already in the accumulator.
addStmsToAcc :: Stms lore -> DistAcc lore -> DistAcc lore
addStmsToAcc stms acc =
  acc {distStms = stms <> distStms acc}

-- | Convert a SOACS statement with the environment's
-- 'distOnSOACSStms' callback and prepend the resulting statements to
-- the accumulator.
addStmToAcc ::
  (MonadFreshNames m, DistLore lore) =>
  Stm SOACS ->
  DistAcc lore ->
  DistNestT lore m (DistAcc lore)
addStmToAcc stm acc = do
  onSoacs <- asks distOnSOACSStms
  -- Only the statements returned by the callback are kept; anything the
  -- callback emitted into the binder itself is discarded.
  -- NOTE(review): confirm no 'distOnSOACSStms' implementation relies on
  -- binder-emitted statements surviving.
  (stm', _) <- runBinder $ onSoacs stm
  return acc {distStms = stm' <> distStms acc}

-- | Convert a SOACS lambda to the target lore with the environment's
-- 'distOnSOACSLambda' callback.  Statements emitted into the binder
-- are dropped; only the resulting lambda is returned.
soacsLambda ::
  (MonadFreshNames m, DistLore lore) =>
  Lambda SOACS ->
  DistNestT lore m (Lambda lore)
soacsLambda lam = do
  onLambda <- asks distOnSOACSLambda
  fst <$> runBinder (onLambda lam)

-- | The distribution monad: a reader over 'DistEnv' stacked on a writer
-- of 'DistRes', over the underlying monad @m@.  The instances are
-- obtained by newtype deriving.
newtype DistNestT lore m a
  = DistNestT (ReaderT (DistEnv lore m) (WriterT (DistRes lore) m) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (DistEnv lore m),
      MonadWriter (DistRes lore)
    )

-- | Run an action of the underlying monad inside 'DistNestT'.  The
-- action's scope is extended with only those names from the
-- distribution scope that the underlying monad's scope lacks, so
-- names already known to the inner monad keep their inner binding.
liftInner :: (LocalScope lore m, DistLore lore) => m a -> DistNestT lore m a
liftInner m = do
  outer_scope <- askScope
  DistNestT $
    lift $
      lift $ do
        inner_scope <- askScope
        localScope (outer_scope `M.difference` inner_scope) m

-- | Name-source operations are delegated, through the reader layer, to
-- the 'MonadFreshNames' instance of the writer layer beneath it.
instance MonadFreshNames m => MonadFreshNames (DistNestT lore m) where
  getNameSource = DistNestT $ lift getNameSource
  putNameSource = DistNestT . lift . putNameSource

-- | The scope is the 'distScope' field of the environment.
instance (Monad m, ASTLore lore) => HasScope lore (DistNestT lore m) where
  askScope = asks distScope

-- | Extend 'distScope' for the duration of an action.  The new bindings
-- are on the left of the (left-biased) map union, so they shadow
-- existing ones.
instance (Monad m, ASTLore lore) => LocalScope lore (DistNestT lore m) where
  localScope types = local $ \env ->
    env {distScope = types <> distScope env}

-- | Log messages are accumulated in the writer's 'accLog' and forwarded
-- to the underlying monad by 'runDistNestT'.
instance Monad m => MonadLogger (DistNestT lore m) where
  addLog msgs = tell mempty {accLog = msgs}

-- | Run a distribution computation.  The accumulated log is forwarded
-- to the underlying monad.  The result consists of the recorded
-- post-statements followed by statements for any outer targets left in
-- the final accumulator.
runDistNestT ::
  (MonadLogger m, DistLore lore) =>
  DistEnv lore m ->
  DistNestT lore m (DistAcc lore) ->
  m (Stms lore)
runDistNestT env (DistNestT m) = do
  (acc, res) <- runWriterT $ runReaderT m env
  addLog $ accLog res
  -- There may be a few final targets remaining - these correspond to
  -- arrays that are identity mapped, and must have statements
  -- inserted here.
  return $
    unPostStms (accPostStms res)
      <> identityStms (outerTarget $ distTargets acc)
  where
    -- The loop of the first nesting in the list if there is one,
    -- otherwise the loop of the standalone nesting.
    outermost = nestingLoop $
      case distNest env of
        (nest, []) -> nest
        (_, nest : _) -> nest
    params_to_arrs =
      map (first paramName) $
        loopNestingParamsAndArrs outermost

    identityStms (rem_pat, res) =
      stmsFromList $ zipWith identityStm (patternValueElements rem_pat) res

    -- A result that is a parameter of the outermost map is the
    -- corresponding input array, so it is copied.
    identityStm pe (Var v)
      | Just arr <- lookup v params_to_arrs =
        Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ Copy arr
    -- Any other result is invariant to the map, so it is replicated
    -- across the outermost width.
    identityStm pe se =
      Let (Pattern [] [pe]) (defAux ()) $
        BasicOp $
          Replicate (Shape [loopNestingWidth outermost]) se

-- | Record statements in the writer output.  See the 'PostStms'
-- 'Semigroup' instance for how successive recordings are ordered.
addPostStms :: Monad m => PostStms lore -> DistNestT lore m ()
addPostStms ks = tell $ mempty {accPostStms = ks}

-- | 'addPostStms' for plain statements.
postStm :: Monad m => Stms lore -> DistNestT lore m ()
postStm stms = addPostStms $ PostStms stms

-- | Run an action in an environment where the given SOACS statement is
-- in scope.  Its bindings are added to 'distScope', and the names
-- bound by its pattern are registered with the innermost nesting via
-- 'letBindInInnerNesting'.
withStm ::
  (Monad m, DistLore lore) =>
  Stm SOACS ->
  DistNestT lore m a ->
  DistNestT lore m a
withStm stm = local $ \env ->
  env
    { distScope =
        castScope (scopeOf stm) <> distScope env,
      distNest =
        letBindInInnerNesting provided $
          distNest env
    }
  where
    provided = namesFromList $ patternNames $ stmPattern stm

leavingNesting ::
  (MonadFreshNames m, DistLore lore) =>
  DistAcc lore ->
  DistNestT lore m (DistAcc lore)
leavingNesting :: DistAcc lore -> DistNestT lore m (DistAcc lore)
leavingNesting DistAcc lore
acc =
  case Targets -> Maybe ((PatternT Type, Result), Targets)
popInnerTarget (Targets -> Maybe ((PatternT Type, Result), Targets))
-> Targets -> Maybe ((PatternT Type, Result), Targets)
forall a b. (a -> b) -> a -> b
$ DistAcc lore -> Targets
forall lore. DistAcc lore -> Targets
distTargets DistAcc lore
acc of
    Maybe ((PatternT Type, Result), Targets)
Nothing ->
      [Char] -> DistNestT lore m (DistAcc lore)
forall a. HasCallStack => [Char] -> a
error [Char]
"The kernel targets list is unexpectedly small"
    Just ((PatternT Type
pat, Result
res), Targets
newtargets)
      | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Seq (Stm lore) -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (Seq (Stm lore) -> Bool) -> Seq (Stm lore) -> Bool
forall a b. (a -> b) -> a -> b
$ DistAcc lore -> Seq (Stm lore)
forall lore. DistAcc lore -> Stms lore
distStms DistAcc lore
acc -> do
        -- Any statements left over correspond to something that
        -- could not be distributed because it would cause irregular
        -- arrays.  These must be reconstructed into a a Map SOAC
        -- that will be sequentialised. XXX: life would be better if
        -- we were able to distribute irregular parallelism.
        (Nesting Names
_ LoopNesting
inner, [Nesting]
_) <- (DistEnv lore m -> Nestings) -> DistNestT lore m Nestings
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks DistEnv lore m -> Nestings
forall lore (m :: * -> *). DistEnv lore m -> Nestings
distNest
        let MapNesting PatternT Type
_ StmAux ()
aux SubExp
w [(Param Type, VName)]
params_and_arrs = LoopNesting
inner
            body :: BodyT lore
body = BodyDec lore -> Seq (Stm lore) -> Result -> BodyT lore
forall lore. BodyDec lore -> Stms lore -> Result -> BodyT lore
Body () (DistAcc lore -> Seq (Stm lore)
forall lore. DistAcc lore -> Stms lore
distStms DistAcc lore
acc) Result
res
            used_in_body :: Names
used_in_body = BodyT lore -> Names
forall a. FreeIn a => a -> Names
freeIn BodyT lore
body
            ([Param Type]
used_params, [VName]
used_arrs) =
              [(Param Type, VName)] -> ([Param Type], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Param Type, VName)] -> ([Param Type], [VName]))
-> [(Param Type, VName)] -> ([Param Type], [VName])
forall a b. (a -> b) -> a -> b
$
                ((Param Type, VName) -> Bool)
-> [(Param Type, VName)] -> [(Param Type, VName)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName -> Names -> Bool
`nameIn` Names
used_in_body) (VName -> Bool)
-> ((Param Type, VName) -> VName) -> (Param Type, VName) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName (Param Type -> VName)
-> ((Param Type, VName) -> Param Type)
-> (Param Type, VName)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param Type, VName) -> Param Type
forall a b. (a, b) -> a
fst) [(Param Type, VName)]
params_and_arrs
            lam' :: LambdaT lore
lam' =
              Lambda :: forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda
                { lambdaParams :: [LParam lore]
lambdaParams = [Param Type]
[LParam lore]
used_params,
                  lambdaBody :: BodyT lore
lambdaBody = BodyT lore
body,
                  lambdaReturnType :: [Type]
lambdaReturnType = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ PatternT Type -> [Type]
forall dec. Typed dec => PatternT dec -> [Type]
patternTypes PatternT Type
pat
                }
        Seq (Stm lore)
stms <-
          Binder lore () -> DistNestT lore m (Seq (Stm lore))
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore a -> m (Stms lore)
runBinder_ (Binder lore () -> DistNestT lore m (Seq (Stm lore)))
-> Binder lore () -> DistNestT lore m (Seq (Stm lore))
forall a b. (a -> b) -> a -> b
$
            StmAux () -> Binder lore () -> Binder lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux ()
aux (Binder lore () -> Binder lore ())
-> Binder lore () -> Binder lore ()
forall a b. (a -> b) -> a -> b
$
              Pattern (Lore (BinderT lore (State VNameSource)))
-> SOAC (Lore (BinderT lore (State VNameSource))) -> Binder lore ()
forall (m :: * -> *).
Transformer m =>
Pattern (Lore m) -> SOAC (Lore m) -> m ()
FOT.transformSOAC PatternT Type
Pattern (Lore (BinderT lore (State VNameSource)))
pat (SOAC (Lore (BinderT lore (State VNameSource))) -> Binder lore ())
-> SOAC (Lore (BinderT lore (State VNameSource))) -> Binder lore ()
forall a b. (a -> b) -> a -> b
$
                SubExp -> ScremaForm lore -> [VName] -> SOAC lore
forall lore. SubExp -> ScremaForm lore -> [VName] -> SOAC lore
Screma SubExp
w (LambdaT lore -> ScremaForm lore
forall lore. Lambda lore -> ScremaForm lore
mapSOAC LambdaT lore
lam') [VName]
used_arrs

        DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (DistAcc lore -> DistNestT lore m (DistAcc lore))
-> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ DistAcc lore
acc {distTargets :: Targets
distTargets = Targets
newtargets, distStms :: Seq (Stm lore)
distStms = Seq (Stm lore)
stms}
      | Bool
otherwise -> do
        -- Any results left over correspond to a Replicate or a Copy in
        -- the parent nesting, depending on whether the argument is a
        -- parameter of the innermost nesting.
        (Nesting Names
_ LoopNesting
inner_nesting, [Nesting]
_) <- (DistEnv lore m -> Nestings) -> DistNestT lore m Nestings
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks DistEnv lore m -> Nestings
forall lore (m :: * -> *). DistEnv lore m -> Nestings
distNest
        let w :: SubExp
w = LoopNesting -> SubExp
loopNestingWidth LoopNesting
inner_nesting
            aux :: StmAux ()
aux = LoopNesting -> StmAux ()
loopNestingAux LoopNesting
inner_nesting
            inps :: [(Param Type, VName)]
inps = LoopNesting -> [(Param Type, VName)]
loopNestingParamsAndArrs LoopNesting
inner_nesting

            remnantStm :: PatElemT Type -> SubExp -> Stm lore
remnantStm PatElemT Type
pe (Var VName
v)
              | Just (Param Type
_, VName
arr) <- ((Param Type, VName) -> Bool)
-> [(Param Type, VName)] -> Maybe (Param Type, VName)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v) (VName -> Bool)
-> ((Param Type, VName) -> VName) -> (Param Type, VName) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName (Param Type -> VName)
-> ((Param Type, VName) -> Param Type)
-> (Param Type, VName)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param Type, VName) -> Param Type
forall a b. (a, b) -> a
fst) [(Param Type, VName)]
inps =
                Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let ([PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT Type
pe]) StmAux ()
StmAux (ExpDec lore)
aux (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$
                  BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
arr
            remnantStm PatElemT Type
pe SubExp
se =
              Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let ([PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT Type
pe]) StmAux ()
StmAux (ExpDec lore)
aux (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$
                BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate (Result -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
se

            stms :: Seq (Stm lore)
stms =
              [Stm lore] -> Seq (Stm lore)
forall lore. [Stm lore] -> Stms lore
stmsFromList ([Stm lore] -> Seq (Stm lore)) -> [Stm lore] -> Seq (Stm lore)
forall a b. (a -> b) -> a -> b
$ (PatElemT Type -> SubExp -> Stm lore)
-> [PatElemT Type] -> Result -> [Stm lore]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith PatElemT Type -> SubExp -> Stm lore
remnantStm (PatternT Type -> [PatElemT Type]
forall dec. PatternT dec -> [PatElemT dec]
patternElements PatternT Type
pat) Result
res

        DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (DistAcc lore -> DistNestT lore m (DistAcc lore))
-> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ DistAcc lore
acc {distTargets :: Targets
distTargets = Targets
newtargets, distStms :: Seq (Stm lore)
distStms = Seq (Stm lore)
stms}

-- | Run a distribution action inside a new innermost map nesting.
--
-- The nesting is built from the map's pattern, auxiliary information,
-- width, and the lambda parameters paired with the mapped arrays.  The
-- lambda's scope is added in front of the existing distribution scope.
-- Once the inner action has produced its accumulator, 'leavingNesting'
-- is applied to it.  This still happens in the extended environment.
mapNesting ::
  (MonadFreshNames m, DistLore lore) =>
  PatternT Type ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  DistNestT lore m (DistAcc lore) ->
  DistNestT lore m (DistAcc lore)
mapNesting pat aux w lam arrs m =
  local pushMapNesting $ m >>= leavingNesting
  where
    params_and_arrs = zip (lambdaParams lam) arrs
    nest = Nesting mempty (MapNesting pat aux w params_and_arrs)
    -- The lambda's bindings are put on the left of '<>', so they take
    -- precedence over any existing entries with the same name.
    pushMapNesting env =
      env
        { distNest = pushInnerNesting nest (distNest env),
          distScope = castScope (scopeOf lam) <> distScope env
        }

-- | Run an action with the distribution nesting replaced by the given
-- kernel nest.
--
-- The last loop nesting of the kernel nest becomes the innermost
-- distribution nesting.  If there are no further nestings, the outer one
-- takes that role.  All remaining nestings keep their original order.
-- The scope of every loop nesting in the kernel nest is added in front
-- of the existing distribution scope.
inNesting ::
  (Monad m, DistLore lore) =>
  KernelNest ->
  DistNestT lore m a ->
  DistNestT lore m a
inNesting (outer, nests) = local enter
  where
    enter env =
      env
        { distNest = (innermost, surrounding),
          distScope = foldMap scopeOfLoopNesting (outer : nests) <> distScope env
        }

    (innermost, surrounding) =
      case peelInnermost outer nests of
        (inner, others) -> (Nesting mempty inner, map (Nesting mempty) others)

    -- Split the non-empty list @x : xs@ into its last element and the
    -- elements before it.  This is total because the list always has
    -- at least the head @x@.
    peelInnermost x [] = (x, [])
    peelInnermost x (y : ys) =
      let (lst, prefix) = peelInnermost y ys
       in (lst, x : prefix)

-- | Does the body contain a statement worth parallelising?
--
-- A statement qualifies when it is not marked with the @sequential@
-- attribute and it is either:
--
-- * an 'Op' (any SOAC), or
-- * a @for@-loop ('DoLoop' with a 'ForLoop' form) whose body, checked
--   recursively, contains such a statement.
--
-- While-loops and all other expressions never qualify.
bodyContainsParallelism :: Body SOACS -> Bool
bodyContainsParallelism body = any parallelStm (bodyStms body)
  where
    parallelStm stm =
      not (markedSequential stm) && parallelExp (stmExp stm)

    markedSequential stm =
      "sequential" `inAttrs` stmAuxAttrs (stmAux stm)

    parallelExp e =
      case e of
        Op {} -> True
        DoLoop _ _ ForLoop {} loop_body -> bodyContainsParallelism loop_body
        _ -> False

-- | Does the lambda's body contain parallelism?  See
-- 'bodyContainsParallelism' for what counts.
lambdaContainsParallelism :: Lambda SOACS -> Bool
lambdaContainsParallelism lam = bodyContainsParallelism (lambdaBody lam)

-- | Distribute the statements of a map body, starting from the given
-- accumulator.
--
-- Each statement is passed to 'maybeDistributeStm' only after all the
-- statements that follow it have been handled.  Sequential 'Stream's are
-- first expanded with 'sequentialStreamWholeArray', copy-propagated, and
-- spliced back into the statement list.  Whatever the resulting
-- accumulator holds is then handed to 'distribute'.
distributeMapBodyStms ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  DistAcc lore ->
  Stms SOACS ->
  DistNestT lore m (DistAcc lore)
distributeMapBodyStms orig_acc body_stms =
  distribute =<< go orig_acc (stmsToList body_stms)
  where
    go acc [] = pure acc
    go acc (Let pat (StmAux cs _ _) (Op (Stream w (Sequential accs) lam arrs)) : rest) = do
      -- Replace the stream with equivalent non-stream statements and
      -- process those instead.  Only the certificates of the original
      -- statement are carried over; its attributes are not.
      soac_scope <- asksScope scopeForSOACs
      (_, seq_stms) <-
        runBinderT (sequentialStreamWholeArray pat w accs lam arrs) soac_scope
      (_, propagated) <-
        runReaderT (copyPropagateInStms simpleSOACS soac_scope seq_stms) soac_scope
      go acc $ map (certify cs) (stmsToList propagated) ++ rest
    go acc (stm : rest) =
      -- It is important that stm is in scope if 'maybeDistributeStm'
      -- wants to distribute, even if this causes the slightly silly
      -- situation that stm is in scope of itself.
      withStm stm $ do
        acc' <- go acc rest
        maybeDistributeStm stm acc'

-- | Handle an inner map by delegating to the 'distOnInnerMap' callback
-- stored in the distribution environment.
onInnerMap :: Monad m => MapLoop -> DistAcc lore -> DistNestT lore m (DistAcc lore)
onInnerMap loop acc =
  asks distOnInnerMap >>= \handler -> handler loop acc

-- | Transform top-level SOACS statements with the environment's
-- 'distOnTopLevelStms' callback and emit the result via 'postStm'.
onTopLevelStms :: Monad m => Stms SOACS -> DistNestT lore m ()
onTopLevelStms stms =
  asks distOnTopLevelStms >>= \lower -> lower stms >>= postStm

maybeDistributeStm ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  Stm SOACS ->
  DistAcc lore ->
  DistNestT lore m (DistAcc lore)
maybeDistributeStm :: Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
maybeDistributeStm Stm SOACS
stm DistAcc lore
acc
  | Attr
"sequential" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs (Stm SOACS -> StmAux (ExpDec SOACS)
forall lore. Stm lore -> StmAux (ExpDec lore)
stmAux Stm SOACS
stm) =
    Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
stm DistAcc lore
acc
maybeDistributeStm (Let Pattern
pat StmAux (ExpDec SOACS)
aux (Op Op SOACS
soac)) DistAcc lore
acc
  | Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux =
    DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
distributeMapBodyStms DistAcc lore
acc (Stms SOACS -> DistNestT lore m (DistAcc lore))
-> (Stms SOACS -> Stms SOACS)
-> Stms SOACS
-> DistNestT lore m (DistAcc lore)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm SOACS -> Stm SOACS) -> Stms SOACS -> Stms SOACS
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> Stm SOACS -> Stm SOACS
forall lore. Certificates -> Stm lore -> Stm lore
certify (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux))
      (Stms SOACS -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (Stms SOACS) -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Binder SOACS () -> DistNestT lore m (Stms SOACS)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore a -> m (Stms lore)
runBinder_ (Pattern (Lore (BinderT SOACS (State VNameSource)))
-> SOAC (Lore (BinderT SOACS (State VNameSource)))
-> Binder SOACS ()
forall (m :: * -> *).
Transformer m =>
Pattern (Lore m) -> SOAC (Lore m) -> m ()
FOT.transformSOAC Pattern (Lore (BinderT SOACS (State VNameSource)))
Pattern
pat Op SOACS
SOAC (Lore (BinderT SOACS (State VNameSource)))
soac)
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pattern
pat StmAux (ExpDec SOACS)
_ (Op (Screma w form arrs))) DistAcc lore
acc
  | Just Lambda
lam <- ScremaForm SOACS -> Maybe Lambda
forall lore. ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC ScremaForm SOACS
form =
    -- Only distribute inside the map if we can distribute everything
    -- following the map.
    DistAcc lore -> DistNestT lore m (Maybe (DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore -> DistNestT lore m (Maybe (DistAcc lore))
distributeIfPossible DistAcc lore
acc DistNestT lore m (Maybe (DistAcc lore))
-> (Maybe (DistAcc lore) -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
      Maybe (DistAcc lore)
Nothing -> Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
stm DistAcc lore
acc
      Just DistAcc lore
acc' -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore -> DistNestT lore m (DistAcc lore)
distribute (DistAcc lore -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< MapLoop -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
Monad m =>
MapLoop -> DistAcc lore -> DistNestT lore m (DistAcc lore)
onInnerMap (Pattern -> StmAux () -> SubExp -> Lambda -> [VName] -> MapLoop
MapLoop Pattern
pat (Stm SOACS -> StmAux (ExpDec SOACS)
forall lore. Stm lore -> StmAux (ExpDec lore)
stmAux Stm SOACS
stm) SubExp
w Lambda
lam [VName]
arrs) DistAcc lore
acc'
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat StmAux (ExpDec SOACS)
aux (DoLoop [] [(FParam SOACS, SubExp)]
val form :: LoopForm SOACS
form@ForLoop {} Body SOACS
body)) DistAcc lore
acc
  | [PatElemT Type] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (PatternT Type -> [PatElemT Type]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements PatternT Type
Pattern
pat),
    Body SOACS -> Bool
bodyContainsParallelism Body SOACS
body =
    DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
  lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
    -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
      Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
        | -- XXX: We cannot distribute if this loop depends on
          -- certificates bound within the loop nest (well, we could,
          -- but interchange would not be valid).  This is not a
          -- fundamental restriction, but an artifact of our
          -- certificate representation, which we should probably
          -- rethink.
          Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$
            (LoopForm SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm SOACS
form Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> StmAux () -> Names
forall a. FreeIn a => a -> Names
freeIn StmAux ()
StmAux (ExpDec SOACS)
aux)
              Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
          Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
          -- We need to pretend pat_unused was used anyway, by adding
          -- it to the kernel nest.
          Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
 -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
            PostStms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
PostStms lore -> DistNestT lore m ()
addPostStms PostStms lore
kernels
            KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
            Scope SOACS
types <- (Scope lore -> Scope SOACS) -> DistNestT lore m (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope lore -> Scope SOACS
forall lore. SameScope lore SOACS => Scope lore -> Scope SOACS
scopeForSOACs

            -- Simplification is key to hoisting out statements that
            -- were variant to the loop, but invariant to the outer maps
            -- (which are now innermost).
            Stms SOACS
stms <-
              (ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
-> Scope SOACS -> DistNestT lore m (Stms SOACS)
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
`runReaderT` Scope SOACS
types) (ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
 -> DistNestT lore m (Stms SOACS))
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
-> DistNestT lore m (Stms SOACS)
forall a b. (a -> b) -> a -> b
$
                ((SymbolTable (Wise SOACS), Stms SOACS) -> Stms SOACS)
-> ReaderT
     (Scope SOACS)
     (DistNestT lore m)
     (SymbolTable (Wise SOACS), Stms SOACS)
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SymbolTable (Wise SOACS), Stms SOACS) -> Stms SOACS
forall a b. (a, b) -> b
snd (ReaderT
   (Scope SOACS)
   (DistNestT lore m)
   (SymbolTable (Wise SOACS), Stms SOACS)
 -> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS))
-> (Stms SOACS
    -> ReaderT
         (Scope SOACS)
         (DistNestT lore m)
         (SymbolTable (Wise SOACS), Stms SOACS))
-> Stms SOACS
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms SOACS
-> ReaderT
     (Scope SOACS)
     (DistNestT lore m)
     (SymbolTable (Wise SOACS), Stms SOACS)
forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Stms SOACS -> m (SymbolTable (Wise SOACS), Stms SOACS)
simplifyStms
                  (Stms SOACS
 -> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS))
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
-> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< KernelNest
-> SeqLoop -> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
forall (m :: * -> *).
(MonadFreshNames m, HasScope SOACS m) =>
KernelNest -> SeqLoop -> m (Stms SOACS)
interchangeLoops KernelNest
nest' ([Int]
-> Pattern
-> [(FParam SOACS, SubExp)]
-> LoopForm SOACS
-> Body SOACS
-> SeqLoop
SeqLoop [Int]
perm Pattern
pat [(FParam SOACS, SubExp)]
val LoopForm SOACS
form Body SOACS
body)
            Stms SOACS -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
Stms SOACS -> DistNestT lore m ()
onTopLevelStms Stms SOACS
stms
            DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc lore
acc'
      Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
        Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc
-- An 'If' is distributed when it has no context pattern elements and
-- either branch contains parallelism or it returns non-primitive
-- values.  The surrounding kernel nest is then interchanged into the
-- branches with 'interchangeBranch', provided neither the condition
-- nor the return annotations refer to names bound by the nest.
maybeDistributeStm stm@(Let pat _ (If cond tbranch fbranch ret)) acc
  | null (patternContextElements pat),
    bodyContainsParallelism tbranch
      || bodyContainsParallelism fbranch
      || not (all primType (ifReturns ret)) = do
    distributed <- distributeSingleStm acc stm
    case distributed of
      Just (kernels, res, nest, acc')
        | not (branch_deps `namesIntersect` boundInKernelNest nest),
          Just (perm, pat_unused) <- permutationAndMissing pat res ->
          -- Pattern elements that are not used must still be produced,
          -- so they are added to the kernel nest.
          localScope (typeEnvFromDistAcc acc') $ do
            nest' <- expandKernelNest pat_unused nest
            addPostStms kernels
            soac_scope <- asksScope scopeForSOACs
            let branch = Branch perm pat cond tbranch fbranch ret
            -- Interchange, then simplify the result, keeping only the
            -- statements (the symbol table is discarded).
            stms <-
              (`runReaderT` soac_scope) $
                interchangeBranch nest' branch >>= fmap snd . simplifyStms
            onTopLevelStms stms
            return acc'
      _ -> addStmToAcc stm acc
  where
    -- Names the branch decision and its return annotations depend on.
    branch_deps = freeIn cond <> freeIn ret
-- A Screma that is a single reduction, and for which 'irwim' yields a
-- rewrite, is replaced by the statements that rewrite produces (keeping
-- the original statement's auxiliary information).  Those statements
-- are then distributed in turn.
maybeDistributeStm (Let pat aux (Op (Screma w form arrs))) acc
  | Just [Reduce comm lam nes] <- isReduceSOAC form,
    Just reduce_rewrite <- irwim pat w comm lam (zip nes arrs) = do
    soac_scope <- asksScope scopeForSOACs
    rewritten <- snd <$> runBinderT (auxing aux reduce_rewrite) soac_scope
    distributeMapBodyStms acc rewritten

-- Parallelise segmented scatters.
--
-- If the scatter can be distributed on its own, it becomes a segmented
-- scatter kernel ('segmentedScatterKernel'); otherwise it is simply
-- added to the accumulator.
maybeDistributeStm stm@(Let pat (StmAux cs _ _) (Op (Scatter w lam ivs dests))) acc = do
  distributed <- distributeSingleStm acc stm
  case distributed of
    Just (kernels, res, nest, acc')
      | Just (perm, pat_unused) <- permutationAndMissing pat res ->
        localScope (typeEnvFromDistAcc acc') $ do
          -- Unused pattern elements are kept alive by adding them to
          -- the kernel nest.
          nest' <- expandKernelNest pat_unused nest
          lam' <- soacsLambda lam
          addPostStms kernels
          postStm =<< segmentedScatterKernel nest' perm pat cs w lam' ivs dests
          return acc'
    _ -> addStmToAcc stm acc
-- Parallelise segmented Hist.
--
-- If the histogram can be distributed on its own, it becomes a
-- segmented histogram kernel ('segmentedHistKernel'); otherwise it is
-- simply added to the accumulator.
maybeDistributeStm stm@(Let pat (StmAux cs _ _) (Op (Hist w ops lam inputs))) acc = do
  distributed <- distributeSingleStm acc stm
  case distributed of
    Just (kernels, res, nest, acc')
      | Just (perm, pat_unused) <- permutationAndMissing pat res ->
        localScope (typeEnvFromDistAcc acc') $ do
          lam' <- soacsLambda lam
          -- Unused pattern elements are kept alive by adding them to
          -- the kernel nest.
          nest' <- expandKernelNest pat_unused nest
          addPostStms kernels
          postStm =<< segmentedHistKernel nest' perm cs w ops lam' inputs
          return acc'
    _ -> addStmToAcc stm acc
-- Parallelise Index slices if the result is going to be returned
-- directly from the kernel.  This is because we would otherwise have
-- to sequentialise writing the result, which may be costly.
--
-- Only slices with at least one remaining dimension qualify, and only
-- when the bound name is among the results of the innermost target.
maybeDistributeStm stm@(Let (Pattern [] [pe]) aux (BasicOp (Index arr slice))) acc
  | not (null (sliceDims slice)),
    Var (patElemName pe) `elem` snd (innerTarget (distTargets acc)) = do
    distributed <- distributeSingleStm acc stm
    case distributed of
      Just (kernels, _res, nest, acc') ->
        localScope (typeEnvFromDistAcc acc') $ do
          addPostStms kernels
          postStm =<< segmentedGatherKernel nest (stmAuxCerts aux) arr slice
          return acc'
      Nothing -> addStmToAcc stm acc
-- If the scan can be distributed by itself, we will turn it into a
-- segmented scan.
--
-- If the scan cannot be distributed by itself, it will be
-- sequentialised in the default case for this function.
--
-- Whether a kernel is actually produced is decided by 'kernelOrNot',
-- based on the result of 'segmentedScanomapKernel'.
maybeDistributeStm bnd@(Let pat (StmAux cs _ _) (Op (Screma w form arrs))) acc
  | Just (scans, map_lam) <- isScanomapSOAC form,
    Scan lam nes <- singleScan scans =
    distributeSingleStm acc bnd >>= \case
      Just (kernels, res, nest, acc')
        | Just (perm, pat_unused) <- permutationAndMissing pat res ->
          -- We need to pretend pat_unused was used anyway, by adding
          -- it to the kernel nest.
          --
          -- Everything below, including the kernel construction, runs
          -- inside this single scope extension with the types of acc'.
          localScope (typeEnvFromDistAcc acc') $ do
            nest' <- expandKernelNest pat_unused nest
            map_lam' <- soacsLambda map_lam
            lam' <- soacsLambda lam
            segmentedScanomapKernel nest' perm w lam' map_lam' nes arrs
              >>= kernelOrNot cs bnd acc kernels acc'
      _ ->
        addStmToAcc bnd acc
-- If the reduction can be distributed by itself, we will turn it into a
-- segmented reduce.
--
-- If the reduction cannot be distributed by itself, it will be
-- sequentialised in the default case for this function.
maybeDistributeStm stm@(Let pat (StmAux cs _ _) (Op (Screma w form arrs))) acc
  | Just (reds, map_lam) <- isRedomapSOAC form,
    Reduce comm lam nes <- singleReduce reds = do
    distributed <- distributeSingleStm acc stm
    case distributed of
      Just (kernels, res, nest, acc')
        | Just (perm, pat_unused) <- permutationAndMissing pat res ->
          -- Unused pattern elements are kept alive by adding them to
          -- the kernel nest.
          localScope (typeEnvFromDistAcc acc') $ do
            nest' <- expandKernelNest pat_unused nest
            lam' <- soacsLambda lam
            map_lam' <- soacsLambda map_lam
            -- An operator that 'commutativeLambda' accepts is treated as
            -- commutative regardless of the reduction's own annotation.
            let comm' =
                  if commutativeLambda lam
                    then Commutative
                    else comm
            regularSegmentedRedomapKernel nest' perm w comm' lam' map_lam' nes arrs
              >>= kernelOrNot cs stm acc kernels acc'
      _ -> addStmToAcc stm acc
-- Any other Screma is split into simpler statements with
-- 'dissectScrema'.  Each of those statements gets the original
-- certificates and is distributed again.
maybeDistributeStm (Let pat (StmAux cs _ _) (Op (Screma w form arrs))) acc = do
  -- This Screma is too complicated for us to immediately do
  -- anything, so split it up and try again.
  scope <- asksScope scopeForSOACs
  (_, dissected) <- runBinderT (dissectScrema pat w form arrs) scope
  distributeMapBodyStms acc (certify cs <$> dissected)
-- A replicate with at least one dimension @d@ (binding a single
-- pattern element) is rewritten into a parameterless map of width @d@.
-- Its body replicates @v@ over the remaining dimensions @ds@ into a
-- fresh @tmp@.  The resulting map statement is then distributed
-- recursively.
maybeDistributeStm (Let pat aux (BasicOp (Replicate (Shape (d : ds)) v))) acc
  | [t] <- patternTypes pat = do
    tmp <- newVName "tmp"
    let rowt = rowType t
        -- Replicates a single row; bound inside the map body.
        inner_replicate =
          Let (Pattern [] [PatElem tmp rowt]) aux . BasicOp $
            Replicate (Shape ds) v
        map_lam =
          Lambda
            { lambdaParams = [],
              lambdaBody = mkBody (oneStm inner_replicate) [Var tmp],
              lambdaReturnType = [rowt]
            }
        as_map = Let pat aux . Op $ Screma d (mapSOAC map_lam) []
    maybeDistributeStm as_map acc
-- A copy is distributed as a single unary statement.  The callback
-- rebuilds the copy on the array it is handed, binding the given outer
-- pattern and keeping the original auxiliary information.
maybeDistributeStm stm@(Let _ aux (BasicOp (Copy stm_arr))) acc =
  distributeSingleUnaryStm acc stm stm_arr $ \_nest outerpat arr ->
    pure . oneStm . Let outerpat aux . BasicOp $ Copy arr
-- Opaques are applied to the full array, because otherwise they can
-- drastically inhibit parallelisation in some cases.  Only non-primitive
-- results are handled here.  The distributed statement is emitted as a
-- 'Copy' of the outer array rather than an 'Opaque'.
-- NOTE(review): confirm that dropping the 'Opaque' at this level is intended.
maybeDistributeStm stm@(Let (Pattern [] [pe]) aux (BasicOp (Opaque (Var stm_arr)))) acc
  | not $ primType $ typeOf pe =
    distributeSingleUnaryStm acc stm stm_arr $ \_ outerpat arr ->
      return $ oneStm $ Let outerpat aux $ BasicOp $ Copy arr
-- A 'Rearrange' inside the nest becomes a rearrange of the outer array.
-- The dimensions introduced by the nest stay in front (identity
-- permutation), and the original permutation is shifted past them.
maybeDistributeStm stm@(Let _ aux (BasicOp (Rearrange perm stm_arr))) acc =
  distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr -> do
    -- Number of nesting levels: the outermost one plus the inner ones.
    let r = length (snd nest) + 1
        perm' = [0 .. r - 1] ++ map (+ r) perm
    -- We need to add a copy, because the original map nest
    -- will have produced an array without aliases, and so must we.
    arr' <- newVName $ baseString arr
    arr_t <- lookupType arr
    return $
      stmsFromList
        [ Let (Pattern [] [PatElem arr' arr_t]) aux $ BasicOp $ Copy arr,
          Let outerpat aux $ BasicOp $ Rearrange perm' arr'
        ]
-- A 'Reshape' inside the nest becomes a reshape of the outer array.
-- Its leading dimensions are the widths of the kernel nest, followed by
-- the new dimensions of the original reshape.
maybeDistributeStm stm@(Let _ aux (BasicOp (Reshape reshape stm_arr))) acc =
  distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr -> do
    let reshape' =
          map DimNew (kernelNestWidths nest)
            ++ map DimNew (newDims reshape)
    return $ oneStm $ Let outerpat aux $ BasicOp $ Reshape reshape' arr
-- A 'Rotate' inside the nest becomes a rotate of the outer array.  Every
-- dimension introduced by the nest gets a rotation of zero, followed by
-- the original rotations.
maybeDistributeStm stm@(Let _ aux (BasicOp (Rotate rots stm_arr))) acc =
  distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr -> do
    let rots' = map (const $ intConst Int64 0) (kernelNestWidths nest) ++ rots
    return $ oneStm $ Let outerpat aux $ BasicOp $ Rotate rots' arr
-- An in-place 'Update' with a variable value and a slice with non-empty
-- 'sliceDims' is handled as follows:
--
-- * It becomes a segmented update kernel when the statement distributes
--   on its own, its results are exactly its pattern names, and
--   'permutationAndMissing' can relate the pattern to those results.
-- * Otherwise the statement is added to the accumulator.
maybeDistributeStm stm@(Let pat aux (BasicOp (Update arr slice (Var v)))) acc
  | not $ null $ sliceDims slice =
    distributeSingleStm acc stm >>= \case
      Just (kernels, res, nest, acc')
        | res == map Var (patternNames $ stmPattern stm),
          Just (perm, pat_unused) <- permutationAndMissing pat res -> do
          addPostStms kernels
          localScope (typeEnvFromDistAcc acc') $ do
            -- Extend the nest with the pattern elements not covered by res.
            nest' <- expandKernelNest pat_unused nest
            postStm
              =<< segmentedUpdateKernel nest' perm (stmAuxCerts aux) arr slice v
            return acc'
      _ -> addStmToAcc stm acc
-- XXX?  This rule is present to avoid the case where an in-place
-- update is distributed as its own kernel, as this would mean a thread
-- then writes the entire array that it updated.  This is problematic
-- because the in-place update is O(1), but writing the array is
-- O(n).  It is OK if the in-place update is preceded, followed, or
-- nested inside a sequential loop or similar, because that will
-- probably be O(n) by itself.  As a hack, we only distribute if there
-- does not appear to be a loop following.  The better solution is to
-- depend on memory block merging for this optimisation, but it is not
-- ready yet.
--
-- Concretely: a single-index update of a one-dimensional array is
-- rewritten into a one-element 'Scatter' and fed back into
-- 'maybeDistributeStm'.  This happens only when no statement already in
-- the accumulator is a 'DoLoop' or an 'Op'.
maybeDistributeStm (Let pat aux (BasicOp (Update arr [DimFix i] v))) acc
  | [t] <- patternTypes pat,
    arrayRank t == 1,
    not $ any (amortises . stmExp) $ distStms acc = do
    let w = arraySize 0 t
        et = stripArray 1 t
        -- Parameterless lambda yielding the single (index, value) pair.
        lam =
          Lambda
            { lambdaParams = [],
              lambdaReturnType = [Prim int64, et],
              lambdaBody = mkBody mempty [i, v]
            }
    maybeDistributeStm (Let pat aux $ Op $ Scatter (intConst Int64 1) lam [] [(w, 1, arr)]) acc
  where
    -- Statements considered costly enough to hide the O(n) array write.
    amortises DoLoop {} = True
    amortises Op {} = True
    amortises _ = False
-- A 'Concat' that distributes on its own becomes a segmented
-- concatenation via 'isSegmentedOp'.  The concatenation dimension is
-- shifted past the dimensions introduced by the kernel nest.
-- 'kernelOrNot' decides what to do with the result.  Otherwise the
-- statement is added to the accumulator.
maybeDistributeStm stm@(Let _ aux (BasicOp (Concat d x xs w))) acc =
  distributeSingleStm acc stm >>= \case
    Just (kernels, _, nest, acc') ->
      localScope (typeEnvFromDistAcc acc') $
        segmentedConcat nest
          >>= kernelOrNot (stmAuxCerts aux) stm acc kernels acc'
    _ ->
      addStmToAcc stm acc
  where
    segmentedConcat nest =
      isSegmentedOp nest [0] mempty mempty [] (x : xs) $ \pat _ _ _ arrs ->
        -- Match exhaustively instead of with a partial lambda pattern, so a
        -- violated invariant produces a descriptive error.
        case arrs of
          x' : xs' ->
            let d' = d + length (snd nest) + 1
             in addStm $ Let pat aux $ BasicOp $ Concat d' x' xs' w
          [] ->
            error "maybeDistributeStm: segmented concat received no arrays."
-- Any statement not matched by the equations above is handed to
-- 'addStmToAcc'.
maybeDistributeStm stm acc =
  addStmToAcc stm acc

-- | Distribute a statement that operates on a single array @stm_arr@.
--
-- The statement is first distributed with 'distributeSingleStm'.  The
-- callback is used only when all of the following hold:
--
-- * the distributed results are exactly the statement's pattern names;
-- * the outermost nesting has a single parameter/array pair;
-- * that parameter is the only name bound by the kernel nest that the
--   statement uses;
-- * the array is perfectly mapped down to @stm_arr@ (see
--   @perfectlyMapped@).
--
-- In that case the callback receives the nest, the outermost nesting's
-- pattern and the outer array.  The statements it returns are posted
-- under the scope of the updated accumulator.  In every other case the
-- statement is added to the accumulator with 'addStmToAcc'.
distributeSingleUnaryStm ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  DistAcc lore ->
  Stm SOACS ->
  VName ->
  (KernelNest -> PatternT Type -> VName -> DistNestT lore m (Stms lore)) ->
  DistNestT lore m (DistAcc lore)
distributeSingleUnaryStm acc stm stm_arr f =
  distributeSingleStm acc stm >>= \case
    Just (kernels, res, nest, acc')
      | res == map Var (patternNames $ stmPattern stm),
        (outer, _) <- nest,
        [(arr_p, arr)] <- loopNestingParamsAndArrs outer,
        boundInKernelNest nest `namesIntersection` freeIn stm
          == oneName (paramName arr_p),
        perfectlyMapped arr nest -> do
        addPostStms kernels
        let outerpat = loopNestingPattern outer
        localScope (typeEnvFromDistAcc acc') $ do
          postStm =<< f nest outerpat arr
          return acc'
    _ -> addStmToAcc stm acc
  where
    -- Walk the nest from the outside in.  At every level there must be
    -- exactly one parameter/array pair, and its array must be the name
    -- threaded down from the level above.  At the innermost level the
    -- parameter must be @stm_arr@ itself.
    perfectlyMapped :: VName -> KernelNest -> Bool
    perfectlyMapped arr (outer, nest)
      | [(p, arr')] <- loopNestingParamsAndArrs outer,
        arr == arr' =
        case nest of
          [] -> paramName p == stm_arr
          x : xs -> perfectlyMapped (paramName p) (x, xs)
      | otherwise =
        False

-- | Try to distribute the accumulated statements with
-- 'distributeIfPossible'.  If that yields 'Nothing', the accumulator is
-- returned unchanged.
distribute ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  DistAcc lore ->
  DistNestT lore m (DistAcc lore)
distribute acc =
  fromMaybe acc <$> distributeIfPossible acc

-- | Adapt the environment's segment-level constructor ('distSegLevel')
-- to run in 'DistNestT'.
--
-- The returned function runs the constructor in a fresh binder in the
-- inner monad (via 'liftInner' and 'runBinderT''), adds the statements
-- it produced to the current binder, and returns the resulting level.
mkSegLevel ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  DistNestT lore m (MkSegLevel lore (DistNestT lore m))
mkSegLevel = do
  mk_lvl <- asks distSegLevel
  return $ \w desc r -> do
    (lvl, stms) <- lift $ liftInner $ runBinderT' $ mk_lvl w desc r
    addStms stms
    return lvl

-- | Try to turn the statements accumulated in @acc@ into a kernel,
-- using the nesting currently stored in the environment ('distNest').
--
-- If 'tryDistribute' fails, 'Nothing' is returned and nothing is
-- posted.  On success, the statements it produced are recorded with
-- 'postStm'.  The result is then a fresh accumulator carrying the
-- targets reported by 'tryDistribute' and an empty statement list.
distributeIfPossible ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  DistAcc lore ->
  DistNestT lore m (Maybe (DistAcc lore))
distributeIfPossible acc = do
  cur_nest <- asks distNest
  seg_lvl <- mkSegLevel
  attempt <- tryDistribute seg_lvl cur_nest (distTargets acc) (distStms acc)
  -- 'forM' over a 'Maybe' runs the action only in the 'Just' case.
  forM attempt $ \(new_targets, kernel_stms) -> do
    postStm kernel_stms
    pure DistAcc {distTargets = new_targets, distStms = mempty}

-- | Distribute the accumulated statements in @acc@, then try to
-- distribute the single statement @bnd@ on top of the resulting
-- targets.
--
-- Unlike 'distributeIfPossible', the statements produced by
-- 'tryDistribute' are not posted here.  They are returned wrapped in
-- 'PostStms', together with the result and the new 'KernelNest'
-- reported by 'tryDistributeStm', and a fresh accumulator over the new
-- targets with no statements.  If either step fails, the whole result
-- is 'Nothing'.
distributeSingleStm ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  DistAcc lore ->
  Stm SOACS ->
  DistNestT
    lore
    m
    ( Maybe
        ( PostStms lore,
          Result,
          KernelNest,
          DistAcc lore
        )
    )
distributeSingleStm acc bnd = do
  cur_nest <- asks distNest
  seg_lvl <- mkSegLevel
  -- Each 'MaybeT' step short-circuits to 'Nothing' on failure, so the
  -- second attempt only runs if the first one succeeded.
  runMaybeT $ do
    (targets, distributed_stms) <-
      MaybeT $ tryDistribute seg_lvl cur_nest (distTargets acc) (distStms acc)
    (stm_res, targets', inner_nest) <-
      MaybeT $ tryDistributeStm cur_nest targets bnd
    pure
      ( PostStms distributed_stms,
        stm_res,
        inner_nest,
        DistAcc {distTargets = targets', distStms = mempty}
      )

-- | Build a flat kernel for a scatter that occurs inside the map nest
-- @nest@.
--
-- The scatter is treated as one extra, innermost map level of width
-- @scatter_w@ over the inputs @ivs@, and the whole nest is flattened
-- with 'flatKernel'.
--
-- Each destination in @dests@ is a (width, index/value pair count,
-- array) triple.  Its array must be the name of some kernel input;
-- otherwise 'error' is raised.  The kernel writes in place through
-- 'WriteReturns'.  Its result is bound to the value elements of the
-- outermost nesting pattern, rearranged by @perm@, and the returned
-- statements are renamed.
segmentedScatterKernel ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  KernelNest ->
  [Int] ->
  PatternT Type ->
  Certificates ->
  SubExp ->
  Lambda lore ->
  [VName] ->
  [(SubExp, Int, VName)] ->
  DistNestT lore m (Stms lore)
segmentedScatterKernel nest perm scatter_pat cs scatter_w lam ivs dests = do
  -- We replicate some of the checking done by 'isSegmentedOp', but
  -- things are different because a scatter is not a reduction or
  -- scan.
  --
  -- First, pretend that the scatter is also part of the nesting.  The
  -- KernelNest we produce here is technically not sensible, but it's
  -- good enough for flatKernel to work.
  let lam_body = lambdaBody lam
      scatter_nesting =
        MapNesting scatter_pat (StmAux cs mempty ()) scatter_w $
          zip (lambdaParams lam) ivs
  (ispace, kernel_inps) <-
    flatKernel $
      pushInnerKernelNesting
        (scatter_pat, bodyResult lam_body)
        scatter_nesting
        nest

  let (dest_ws, dest_ns, dest_arrs) = unzip3 dests
      -- Total number of index results; the lambda returns all indices
      -- first, followed by all values.
      num_is = sum dest_ns

  -- The input/output arrays _must_ correspond to some kernel input,
  -- or else the original nested scatter would have been ill-typed.
  -- Find them.
  dest_inps <- mapM (findInput kernel_inps) dest_arrs

  seg_lvl <- mkSegLevel

  let (is, vs) = splitAt num_is $ bodyResult lam_body
      -- For each destination, keep (at most) the first of its value types.
      rts = concatMap (take 1) $ chunks dest_ns $ drop num_is $ lambdaReturnType lam

  -- Maybe add certificates to the indices.
  (is', k_body_stms) <- runBinder $ do
    addStms $ bodyStms lam_body
    if cs == mempty
      then pure is
      else forM is $ certifying cs . letSubExp "scatter_i" . BasicOp . SubExp

  let k_body =
        KernelBody () k_body_stms
          [ inPlaceReturn ispace dest
            | dest <- zip3 dest_ws dest_inps $ chunks dest_ns $ zip is' vs
          ]

  (k, k_stms) <- mapKernel seg_lvl ispace kernel_inps rts k_body

  stms <- runBinder_ $ do
    addStms k_stms
    letBind result_pat $ Op $ segOp k
  traverse renameStm stms
  where
    -- The value elements of the outermost nesting, permuted by 'perm'.
    result_pat =
      Pattern [] $
        rearrangeShape perm $
          patternValueElements $ loopNestingPattern $ fst nest

    findInput kernel_inps a =
      case find ((== a) . kernelInputName) kernel_inps of
        Just inp -> pure inp
        Nothing -> error "Ill-typed nested scatter encountered."

    -- All thread indices except the innermost (which belongs to the
    -- scatter itself) select the row; the scatter index picks the
    -- element within it.
    inPlaceReturn ispace (aw, inp, is_vs) =
      let (gtids, ws) = unzip ispace
          outer_is = map Var $ init gtids
          writeAt (i, v) = (map DimFix $ outer_is ++ [i], v)
       in WriteReturns (init ws ++ [aw]) (kernelInputArray inp) $ map writeAt is_vs

-- | Build a flat kernel that performs the in-place update
-- @arr[slice] = v@ inside the map nest @nest@.
--
-- The index space is the flattened nest followed by one fresh
-- @gtid_slice@ dimension per dimension of @slice@ ('sliceDims').  Each
-- thread reads one element of @v@, under certificates @cs@.  It writes
-- that element through 'WriteReturns' at the outer thread indices
-- followed by the slice position computed with 'fixSlice'.
--
-- @arr@ must name a kernel input; otherwise 'error' is raised.  The
-- result is bound to the value elements of the outermost nesting
-- pattern, rearranged by @perm@, and the returned statements are
-- renamed.
segmentedUpdateKernel ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  KernelNest ->
  [Int] ->
  Certificates ->
  VName ->
  Slice SubExp ->
  VName ->
  DistNestT lore m (Stms lore)
segmentedUpdateKernel nest perm cs arr slice v = do
  (base_ispace, kernel_inps) <- flatKernel nest
  let slice_dims = sliceDims slice
  slice_gtids <- replicateM (length slice_dims) (newVName "gtid_slice")

  let ispace = base_ispace ++ zip slice_gtids slice_dims
      outer_is = map (Var . fst) base_ispace

  ((res_t, res), kstms) <- runBinder $ do
    -- Compute indexes into full array.
    v' <-
      certifying cs . letSubExp "v" . BasicOp . Index v $
        map (DimFix . Var) slice_gtids
    slice_is <-
      mapM (toSubExp "index") $
        fixSlice (map (fmap pe64) slice) (map (pe64 . Var) slice_gtids)

    let arr' = case find ((== arr) . kernelInputName) kernel_inps of
          Just inp -> kernelInputArray inp
          Nothing -> error "incorrectly typed Update"
    arr_t <- lookupType arr'
    v_t <- subExpType v'
    pure
      ( v_t,
        WriteReturns (arrayDims arr_t) arr' [(map DimFix (outer_is ++ slice_is), v')]
      )

  seg_lvl <- mkSegLevel
  (k, prestms) <-
    mapKernel seg_lvl ispace kernel_inps [res_t] (KernelBody () kstms [res])

  stms <- runBinder_ $ do
    addStms prestms
    letBind result_pat $ Op $ segOp k
  traverse renameStm stms
  where
    -- The value elements of the outermost nesting, permuted by 'perm'.
    result_pat =
      Pattern [] $
        rearrangeShape perm $
          patternValueElements $ loopNestingPattern $ fst nest

-- | Distribute a nested 'Index' (a gather) as a flat map kernel.
--
-- The kernel's index space is the flattened space of @nest@ (from
-- 'flatKernel') extended with one fresh @gtid_slice@ index per
-- dimension of @slice@.
--
-- * Each thread indexes @arr@ (under the certificates @cs@) at the
--   position obtained by composing @slice@ with its own slice indices.
-- * Each thread returns that value as its kernel result.
--
-- The resulting 'SegOp' is bound to the value elements of the pattern
-- of the first 'LoopNesting' of @nest@. All produced statements are
-- renamed before being returned.
segmentedGatherKernel ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  KernelNest ->
  Certificates ->
  VName ->
  Slice SubExp ->
  DistNestT lore m (Stms lore)
segmentedGatherKernel nest cs arr slice = do
  let slice_dims = sliceDims slice
  slice_gtids <- replicateM (length slice_dims) (newVName "gtid_slice")

  (base_ispace, kernel_inps) <- flatKernel nest
  let ispace = base_ispace ++ zip slice_gtids slice_dims

  ((res_t, res), kstms) <- runBinder $ do
    -- Compute indexes into full array.
    slice'' <-
      subExpSlice $
        sliceSlice (primExpSlice slice) $
          primExpSlice $ map (DimFix . Var) slice_gtids
    v' <- certifying cs $ letSubExp "v" $ BasicOp $ Index arr slice''
    v_t <- subExpType v'
    return (v_t, Returns ResultMaySimplify v')

  mk_lvl <- mkSegLevel
  (k, prestms) <-
    mapKernel mk_lvl ispace kernel_inps [res_t] $
      KernelBody () kstms [res]

  traverse renameStm <=< runBinder_ $ do
    -- Statements needed to set up the kernel must precede it.
    addStms prestms

    let pat = Pattern [] $ patternValueElements $ loopNestingPattern $ fst nest

    letBind pat $ Op $ segOp k

-- | Distribute a 'Hist' nested inside the maps of @nest@ as a
-- segmented histogram.
--
-- * @perm@ is applied (via 'rearrangeShape') to the value elements of
--   the first nesting's pattern to obtain the result pattern.
-- * @cs@ are the certificates attached to the produced statements.
-- * @hist_w@ is the width of the histogram input. It is also the
--   leading dimension passed to the segment-level constructor.
-- * @ops@ are the histogram operators. Their destination arrays are
--   mapped to the arrays of the corresponding kernel inputs.
-- * @lam@ and @arrs@ are the bucket function and its input arrays. Both
--   are forwarded unchanged to 'histKernel'.
--
-- Calls 'error' if a destination does not correspond to any kernel
-- input, because that can only happen for an ill-typed program.
segmentedHistKernel ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  KernelNest ->
  [Int] ->
  Certificates ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda lore ->
  [VName] ->
  DistNestT lore m (Stms lore)
segmentedHistKernel nest perm cs hist_w ops lam arrs = do
  -- We replicate some of the checking done by 'isSegmentedOp', but
  -- things are different because a Hist is not a reduction or
  -- scan.
  (ispace, inputs) <- flatKernel nest
  let orig_pat =
        Pattern [] $
          rearrangeShape perm $
            patternValueElements $ loopNestingPattern $ fst nest

  -- The input/output arrays _must_ correspond to some kernel input,
  -- or else the original nested Hist would have been ill-typed.
  -- Find them.
  ops' <- forM ops $ \(SOACS.HistOp num_bins rf dests nes op) ->
    SOACS.HistOp num_bins rf
      <$> mapM (fmap kernelInputArray . findInput inputs) dests
      <*> pure nes
      <*> pure op

  mk_lvl <- asks distSegLevel
  onLambda <- asks distOnSOACSLambda
  -- Run the SOACS-lambda transformation in its own binder, keeping
  -- only the resulting lambda.
  let onLambda' = fmap fst . runBinder . onLambda
  liftInner $
    runBinderT'_ $ do
      -- It is important not to launch unnecessarily many threads for
      -- histograms, because it may mean we unnecessarily need to reduce
      -- subhistograms as well.
      lvl <- mk_lvl (hist_w : map snd ispace) "seghist" $ NoRecommendation SegNoVirt
      addStms
        =<< histKernel onLambda' lvl orig_pat ispace inputs cs hist_w ops' lam arrs
  where
    -- Kept monadic (rather than returning a pure value) so that the
    -- error is raised as soon as the lookup is sequenced.
    findInput kernel_inps a =
      maybe bad return $ find ((== a) . kernelInputName) kernel_inps
    bad = error "Ill-typed nested Hist encountered."

-- | Build the statements for a 'segHist' at level @lvl@.
--
-- For each SOACS histogram operator:
--
-- * the operator is first passed through 'determineReduceOp', which may
--   peel vectorised maps off the operator into a shape and adjust the
--   neutral elements accordingly;
-- * the operator is then converted to the target representation with
--   @onLambda@.
--
-- Kernel inputs whose array is one of the histogram destinations are
-- dropped from the input list given to 'segHist'; the destinations
-- reach it through the operators instead. The produced statements are
-- renamed and added under the certificates @cs@.
histKernel ::
  (MonadBinder m, DistLore (Lore m)) =>
  (Lambda SOACS -> m (Lambda (Lore m))) ->
  SegOpLevel (Lore m) ->
  PatternT Type ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  Certificates ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda (Lore m) ->
  [VName] ->
  m (Stms (Lore m))
histKernel onLambda lvl orig_pat ispace inputs cs hist_w ops lam arrs =
  runBinderT'_ $ do
    ops' <- forM ops $ \(SOACS.HistOp num_bins rf dests nes op) -> do
      (op', nes', shape) <- determineReduceOp op nes
      op'' <- lift $ onLambda op'
      return $ HistOp num_bins rf dests nes' shape op''

    let isDest = (`elem` concatMap histDest ops')
        inputs' = filter (not . isDest . kernelInputArray) inputs

    certifying cs $
      addStms =<< traverse renameStm
        =<< segHist lvl orig_pat hist_w ispace inputs' ops' lam arrs

-- | Split a histogram operator into a scalar-ish operator plus the
-- shape of the vectorised maps wrapped around it (see 'isVectorMap').
--
-- Vectorisation is only attempted when every neutral element is a
-- variable. In that case, and when the operator really is a vector map
-- of rank @r > 0@, each neutral element is indexed at @0@ in its
-- leading @r@ dimensions. The resulting @hist_ne@ bindings become the
-- new neutral elements.
--
-- Otherwise the operator and neutral elements are returned unchanged
-- with an empty shape. This includes the rank-0 case, which previously
-- emitted a no-op full-slice 'Index' per neutral element.
determineReduceOp ::
  (MonadBinder m, Lore m ~ lore) =>
  Lambda SOACS ->
  [SubExp] ->
  m (Lambda SOACS, [SubExp], Shape)
determineReduceOp lam nes =
  -- FIXME? We are assuming that the accumulator is a replicate, and
  -- we fish out its value in a gross way.
  case mapM subExpVar nes of
    Just ne_vs'
      | (shape, lam') <- isVectorMap lam,
        shapeRank shape > 0 -> do
        nes' <- forM ne_vs' $ \ne_v -> do
          ne_v_t <- lookupType ne_v
          letSubExp "hist_ne" $
            BasicOp $
              Index ne_v $
                fullSlice ne_v_t $
                  replicate (shapeRank shape) $ DimFix $ intConst Int64 0
        return (lam', nes', shape)
    -- Either some neutral element is not a variable, or the operator
    -- is not a vector map (isVectorMap then returns lam itself).
    _ ->
      return (lam, nes, mempty)

-- | Peel perfectly nested maps off a lambda.
--
-- A lambda counts as a vector map when all of the following hold:
--
-- * its body is exactly one 'Screma' statement that is a map;
-- * the lambda returns precisely that statement's pattern;
-- * the map consumes precisely the lambda's parameters, in order.
--
-- In that case the map width is prepended to the shape found by
-- recursing into the inner map's lambda. Otherwise the result is
-- @(mempty, lam)@.
isVectorMap :: Lambda SOACS -> (Shape, Lambda SOACS)
isVectorMap lam
  | [Let (Pattern [] pes) _ (Op (Screma w form arrs))] <-
      stmsToList $ bodyStms $ lambdaBody lam,
    bodyResult (lambdaBody lam) == map (Var . patElemName) pes,
    Just map_lam <- isMapSOAC form,
    arrs == map paramName (lambdaParams lam) =
    let (shape, lam') = isVectorMap map_lam
     in (Shape [w] <> shape, lam')
  | otherwise = (mempty, lam)

-- | Try to turn a scanomap that sits inside the given kernel nest into a
-- single segmented scan (built with 'segScan').  The scan operator @lam@
-- is always treated as 'Noncommutative'.
--
-- The checks and the construction of the result pattern, index space and
-- kernel inputs are delegated to 'isSegmentedOp'; if that rejects the nest,
-- the result is 'Nothing'.
--
-- Parameters:
--
-- * @nest@ - the kernel nest surrounding the scanomap.
-- * @perm@ - permutation forwarded to 'isSegmentedOp' for the result pattern.
-- * @segment_size@ - passed to 'segScan', and prepended to the index-space
--   widths when asking the environment ('distSegLevel') for a level.
-- * @lam@ - the scan operator.
-- * @map_lam@ - the map function fused into the scan.
-- * @nes@ - neutral elements of the scan.
-- * @arrs@ - input arrays.
--
-- Note that @arrs@ is deliberately /not/ handed to 'isSegmentedOp' (an empty
-- list is passed instead).  They are therefore not validated, replicated or
-- mapped to kernel inputs there, but given to 'segScan' unchanged.
segmentedScanomapKernel ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  KernelNest ->
  [Int] ->
  SubExp ->
  Lambda lore ->
  Lambda lore ->
  [SubExp] ->
  [VName] ->
  DistNestT lore m (Maybe (Stms lore))
segmentedScanomapKernel nest perm segment_size lam map_lam nes arrs = do
  mk_lvl <- asks distSegLevel
  isSegmentedOp nest perm (freeIn lam) (freeIn map_lam) nes [] $
    \pat ispace inps nes' _ -> do
      let scan_op = SegBinOp Noncommutative lam nes' mempty
      lvl <-
        mk_lvl (segment_size : map snd ispace) "segscan" $
          NoRecommendation SegNoVirt
      -- The generated statements are renamed before being added
      -- (presumably so that names stay unique; verify against segScan).
      addStms =<< traverse renameStm
        =<< segScan lvl pat segment_size [scan_op] map_lam arrs ispace inps

-- | Try to turn a redomap that sits inside the given kernel nest into a
-- single segmented reduction (built with 'segRed').
--
-- The checks and the construction of the result pattern, index space and
-- kernel inputs are delegated to 'isSegmentedOp'; if that rejects the nest,
-- the result is 'Nothing'.
--
-- Parameters:
--
-- * @nest@ - the kernel nest surrounding the redomap.
-- * @perm@ - permutation forwarded to 'isSegmentedOp' for the result pattern.
-- * @segment_size@ - passed to 'segRed', and prepended to the index-space
--   widths when asking the environment ('distSegLevel') for a level.
-- * @comm@ - commutativity of the reduction operator.
-- * @lam@ - the reduction operator.
-- * @map_lam@ - the map function fused into the reduction.
-- * @nes@ - neutral elements of the reduction.
-- * @arrs@ - input arrays.
--
-- As in 'segmentedScanomapKernel', @arrs@ is not handed to 'isSegmentedOp'
-- (an empty list is passed).  The arrays go to 'segRed' unchanged.
regularSegmentedRedomapKernel ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  KernelNest ->
  [Int] ->
  SubExp ->
  Commutativity ->
  Lambda lore ->
  Lambda lore ->
  [SubExp] ->
  [VName] ->
  DistNestT lore m (Maybe (Stms lore))
regularSegmentedRedomapKernel nest perm segment_size comm lam map_lam nes arrs = do
  mk_lvl <- asks distSegLevel
  isSegmentedOp nest perm (freeIn lam) (freeIn map_lam) nes [] $
    \pat ispace inps nes' _ -> do
      let red_op = SegBinOp comm lam nes' mempty
      lvl <-
        mk_lvl (segment_size : map snd ispace) "segred" $
          NoRecommendation SegNoVirt
      -- The generated statements are renamed before being added
      -- (presumably so that names stay unique; verify against segRed).
      addStms =<< traverse renameStm
        =<< segRed lvl pat segment_size [red_op] map_lam arrs ispace inps

-- | Shared driver for turning an operation nested inside a kernel nest into
-- one segmented operation.
--
-- The whole computation runs in 'MaybeT'.  Every 'fail' below aborts the
-- transformation: the overall result becomes 'Nothing' and the failure
-- message is discarded.
--
-- Parameters:
--
-- * @nest@ - the kernel nest.  It is flattened with 'flatKernel' into an
--   index space (pairs of index name and width) and kernel inputs.
-- * @perm@ - permutation applied with 'rearrangeShape' to the value elements
--   of the outermost nesting's pattern.  The result is the pattern passed
--   to the callback.
-- * @free_in_op@ - names free in the operator.  They must not intersect the
--   names bound by the nest.
-- * second 'Names' argument - currently unused.
-- * @nes@ - neutral elements.  A variable bound by the nest is rejected;
--   everything else is passed through unchanged.
-- * @arrs@ - input arrays.  Each one must satisfy one of two conditions:
--
--     * it is a kernel input indexed by exactly the index-space variables,
--       in which case that input's underlying array is used; or
--     * it is not bound by the nest at all, in which case it is replicated
--       over the index-space widths (a @Replicate@ bound to a name based on
--       @<arr>_repd@).
--
--   Any other array makes the transformation fail.
-- * @m@ - callback that emits the segmented operation.  It receives the
--   result pattern, index space, kernel inputs, prepared neutral elements
--   and prepared arrays.
--
-- The statements emitted (the replicates plus whatever @m@ adds) are
-- collected with @runBinderT'_@ and returned.
isSegmentedOp ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  KernelNest ->
  [Int] ->
  Names ->
  Names ->
  [SubExp] ->
  [VName] ->
  ( PatternT Type ->
    [(VName, SubExp)] ->
    [KernelInput] ->
    [SubExp] ->
    [VName] ->
    BinderT lore m ()
  ) ->
  DistNestT lore m (Maybe (Stms lore))
isSegmentedOp nest perm free_in_op _free_in_fold_op nes arrs m = runMaybeT $ do
  -- We must verify that array inputs to the operation are inputs to
  -- the outermost loop nesting or free in the loop nest.  Nothing
  -- free in the op may be bound by the nest.  Furthermore, the
  -- neutral elements must be free in the loop nest.
  --
  -- We must summarise any names from free_in_op that are bound in the
  -- nest, and describe how to obtain them given segment indices.

  let bound_by_nest = boundInKernelNest nest

  (ispace, kernel_inps) <- flatKernel nest

  when (free_in_op `namesIntersect` bound_by_nest) $
    fail "Non-fold lambda uses nest-bound parameters."

  let indices = map fst ispace

      prepareNe (Var v)
        | v `nameIn` bound_by_nest =
          fail "Neutral element bound in nest"
      prepareNe ne = return ne

      -- Produces a Binder action (run later, inside runBinderT'_) that
      -- yields the array to use for this input.
      prepareArr arr =
        case find ((== arr) . kernelInputName) kernel_inps of
          Just inp
            | kernelInputIndices inp == map Var indices ->
              return $ return $ kernelInputArray inp
          Nothing
            | not (arr `nameIn` bound_by_nest) ->
              -- This input is something that is free inside
              -- the loop nesting. We will have to replicate
              -- it.
              return $
                letExp
                  (baseString arr ++ "_repd")
                  (BasicOp $ Replicate (Shape $ map snd ispace) $ Var arr)
          _ ->
            fail "Input not free, perfectly mapped, or outermost."

  nes' <- mapM prepareNe nes

  mk_arrs <- mapM prepareArr arrs

  lift $
    liftInner $
      runBinderT'_ $ do
        nested_arrs <- sequence mk_arrs

        let pat =
              Pattern [] $
                rearrangeShape perm $
                  patternValueElements $
                    loopNestingPattern $ fst nest

        m pat ispace kernel_inps nes' nested_arrs

-- | Relate the value elements of a pattern to a list of results.
--
-- A pattern element is "missing" when its name is not free in @res@.  The
-- missing elements are appended, as variables, to @res@.  The pattern's
-- names are then matched against this expanded list with 'isPermutationOf'
-- (from "Futhark.Util").
--
-- Returns the permutation found together with the missing pattern
-- elements, or 'Nothing' when 'isPermutationOf' fails.
permutationAndMissing :: PatternT Type -> [SubExp] -> Maybe ([Int], [PatElemT Type])
permutationAndMissing pat res = do
  let pes = patternValueElements pat
      res_names = freeIn res
      unused = filter (not . (`nameIn` res_names) . patElemName) pes
      res_expanded = res ++ map (Var . patElemName) unused
  perm <- map (Var . patElemName) pes `isPermutationOf` res_expanded
  return (perm, unused)

-- | Add extra pattern elements to every kernel nesting level.
--
-- Every nesting level receives its own freshly named copy of each element
-- of @pes@.  The copy's type is the original element type wrapped (via
-- 'arrayOfShape') in the widths of that level and of all levels nested
-- inside it.  The outermost level therefore gets the widths of the whole
-- nest, and the innermost level gets only its own width.
--
-- The existing pattern of each level is rebuilt as a pattern with an empty
-- context part, holding all of its old elements followed by the new ones.
-- NOTE(review): any context elements of the old pattern become value
-- elements; presumably loop-nesting patterns have no context - confirm.
expandKernelNest ::
  MonadFreshNames m =>
  [PatElemT Type] ->
  KernelNest ->
  m KernelNest
expandKernelNest pes (outer_nest, inner_nests) = do
  let inner_widths = map loopNestingWidth inner_nests
      outer_size = loopNestingWidth outer_nest : inner_widths
      -- The i'th inner nest is paired with the widths from i onwards;
      -- zipWithM drops the final empty suffix produced by 'tails'.
      inner_sizes = tails inner_widths
  outer_nest' <- expandWith outer_nest outer_size
  inner_nests' <- zipWithM expandWith inner_nests inner_sizes
  return (outer_nest', inner_nests')
  where
    -- Append fresh, expanded copies of 'pes' to one nesting's pattern.
    expandWith nest dims = do
      pes' <- mapM (expandPatElemWith dims) pes
      return
        nest
          { loopNestingPattern =
              Pattern [] $
                patternElements (loopNestingPattern nest) <> pes'
          }

    -- Fresh name based on the old one; type wrapped in the given dims.
    expandPatElemWith dims pe = do
      name <- newVName $ baseString $ patElemName pe
      return
        pe
          { patElemName = name,
            patElemDec = patElemType pe `arrayOfShape` Shape dims
          }

-- | Commit to the outcome of an attempted kernel transformation.
--
-- * 'Nothing': no kernel statements were produced.  The original SOACS
--   statement, certified with @cs@, is added to the accumulator @acc@
--   with 'addStmToAcc'.  The post-statements and the alternative
--   accumulator are ignored.
-- * 'Just' @bnds@: the given post-statements are recorded with
--   'addPostStms'.  Then @bnds@, each certified with @cs@, is recorded
--   with 'postStm', and the alternative accumulator @acc'@ is returned.
kernelOrNot ::
  (MonadFreshNames m, DistLore lore) =>
  Certificates ->
  Stm SOACS ->
  DistAcc lore ->
  PostStms lore ->
  DistAcc lore ->
  Maybe (Stms lore) ->
  DistNestT lore m (DistAcc lore)
kernelOrNot cs bnd acc _ _ Nothing =
  addStmToAcc (certify cs bnd) acc
kernelOrNot cs _ _ kernels acc' (Just bnds) = do
  addPostStms kernels
  postStm $ fmap (certify cs) bnds
  return acc'

-- | Distribute a @map@ (given as a 'MapLoop') over the current accumulator.
--
-- The work happens inside 'mapNesting', which is given the map's pattern,
-- auxiliary info, width, lambda and input arrays.  Inside it:
--
-- * a fresh accumulator is created.  Its targets are the current targets
--   with the map's pattern and body result pushed as the innermost target
--   ('pushInnerTarget'), and its statements are empty;
-- * the statements of the map's body are distributed into that
--   accumulator ('distributeMapBodyStms');
-- * the result is 'distribute'd.
--
-- The accumulator that 'mapNesting' returns is then 'distribute'd once more.
distributeMap ::
  (MonadFreshNames m, LocalScope lore m, DistLore lore) =>
  MapLoop ->
  DistAcc lore ->
  DistNestT lore m (DistAcc lore)
distributeMap (MapLoop pat aux w lam arrs) acc =
  distribute
    =<< mapNesting
      pat
      aux
      w
      lam
      arrs
      (distribute =<< distributeMapBodyStms acc' lam_bnds)
  where
    acc' =
      DistAcc
        { distTargets =
            pushInnerTarget
              (pat, bodyResult $ lambdaBody lam)
              $ distTargets acc,
          distStms = mempty
        }

    lam_bnds = bodyStms $ lambdaBody lam