{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
module Futhark.Pass.ExtractKernels.DistributeNests
  ( MapLoop(..)
  , mapLoopStm

  , bodyContainsParallelism
  , lambdaContainsParallelism
  , determineReduceOp
  , incrementalFlattening
  , histKernel

  , DistEnv (..)
  , DistAcc (..)
  , runDistNestT
  , DistNestT

  , distributeMap

  , distribute
  , distributeSingleStm
  , distributeMapBodyStms
  , addStmsToAcc
  , addStmToAcc
  , permutationAndMissing
  , addPostStms
  , postStm
  , inNesting
  )
where

import Control.Arrow (first)
import Control.Monad.Identity
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.Writer.Strict
import Control.Monad.Trans.Maybe
import Data.Maybe
import Data.List (find, partition, tails)

import Futhark.Representation.AST
import qualified Futhark.Representation.SOACS as SOACS
import Futhark.Representation.SOACS.SOAC hiding (HistOp, histDest)
import Futhark.Representation.SOACS (SOACS)
import Futhark.Representation.SOACS.Simplify (simpleSOACS)
import Futhark.Representation.SegOp
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.CopyPropagate
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.Interchange
import Futhark.Util
import Futhark.Util.Log

-- | Reinterpret a scope of some representation @lore@ as a 'SOACS'
-- scope by means of 'castScope'; the 'SameScope' constraint is what
-- licenses the cast.
scopeForSOACs :: SameScope lore SOACS => Scope lore -> Scope SOACS
scopeForSOACs scope = castScope scope

-- | A @map@ SOAC in deconstructed form; 'mapLoopStm' turns it back
-- into a statement.
data MapLoop =
  MapLoop
    SOACS.Pattern  -- ^ Pattern bound by the statement.
    Certificates   -- ^ Certificates attached to the statement.
    SubExp         -- ^ Width (first argument of 'Screma').
    SOACS.Lambda   -- ^ The mapped function.
    [VName]        -- ^ Input arrays.

-- | Rebuild the statement for a 'MapLoop': a 'Screma' with a
-- 'mapSOAC' form, bound to the loop's pattern and carrying its
-- certificates.
mapLoopStm :: MapLoop -> Stm SOACS
mapLoopStm (MapLoop pat cs w lam arrs) =
  let soac = Screma w (mapSOAC lam) arrs
  in Let pat (StmAux cs ()) (Op soac)

-- | Read-only environment of 'DistNestT': the current nesting and
-- scope, plus the callbacks used to translate 'SOACS' code into the
-- target representation @lore@.
data DistEnv lore m =
  DistEnv
  { distNest :: Nestings
    -- ^ Current nestings; extended by 'mapNesting' and 'inNesting'.
  , distScope :: Scope lore
    -- ^ Scope in the target representation.
  , distOnTopLevelStms :: Stms SOACS -> DistNestT lore m (Stms lore)
    -- ^ Callback translating 'SOACS' statements to @lore@ statements.
  , distOnInnerMap :: MapLoop -> DistAcc lore
                   -> DistNestT lore m (DistAcc lore)
    -- ^ Callback given a 'MapLoop' and the current accumulator;
    -- presumably handles maps nested in the one being distributed
    -- (TODO confirm).
  , distOnSOACSStms :: Stm SOACS -> Binder lore (Stms lore)
    -- ^ Converts a single 'SOACS' statement; used by 'addStmToAcc'.
  , distOnSOACSLambda :: Lambda SOACS -> Binder lore (Lambda lore)
    -- ^ Converts a 'SOACS' lambda; used by 'soacsLambda'.
  , distSegLevel :: MkSegLevel lore m
    -- ^ Level constructor for generated segmented operations
    -- (NOTE(review): semantics defined by 'MkSegLevel' elsewhere).
  }

-- | Accumulator threaded through distribution.
data DistAcc lore =
  DistAcc
  { distTargets :: Targets
    -- ^ Distribution targets; 'outerTarget' yields a pattern/result
    -- pair from them.
  , distStms :: Stms lore
    -- ^ Statements collected so far; new ones are prepended.
  }

-- | The writer output of 'DistNestT'.
data DistRes lore =
  DistRes
  { accPostStms :: PostStms lore
    -- ^ Statements emitted at the front of 'runDistNestT''s result.
  , accLog :: Log
    -- ^ Log messages, forwarded to the base monad by 'runDistNestT'.
  }

-- | Combines both components field by field.
instance Semigroup (DistRes lore) where
  DistRes post1 msgs1 <> DistRes post2 msgs2 =
    DistRes (post1 <> post2) (msgs1 <> msgs2)

instance Monoid (DistRes lore) where
  -- No post-statements and an empty log.
  mempty = DistRes { accPostStms = mempty, accLog = mempty }

-- | Statements recorded with 'addPostStms' / 'postStm' and emitted by
-- 'runDistNestT'.  A newtype so it can have its own 'Semigroup'
-- instance, which concatenates in reverse order.
newtype PostStms lore =
  PostStms { unPostStms :: Stms lore }

-- | Note the reversed order: the statements of the right operand come
-- first in the result.  (This is still associative, with 'mempty' as
-- identity.)
instance Semigroup (PostStms lore) where
  xs <> ys = PostStms (unPostStms ys <> unPostStms xs)

instance Monoid (PostStms lore) where
  mempty = PostStms { unPostStms = mempty }

-- | The scope bound by the pattern of the accumulator's outer target.
typeEnvFromDistAcc :: DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc acc = scopeOfPattern outer_pat
  where (outer_pat, _) = outerTarget $ distTargets acc

-- | Prepend statements (already in the target representation) to those
-- held by the accumulator.
addStmsToAcc :: Stms lore -> DistAcc lore -> DistAcc lore
addStmsToAcc stms acc =
  let existing = distStms acc
  in acc { distStms = stms <> existing }

-- | Convert a 'SOACS' statement with the environment's
-- 'distOnSOACSStms' callback and prepend the resulting statements to
-- the accumulator.  Only the statements the callback /returns/ are
-- kept; anything it adds to the binder itself is discarded.
addStmToAcc :: (MonadFreshNames m, DistLore lore) =>
               Stm SOACS -> DistAcc lore
            -> DistNestT lore m (DistAcc lore)
addStmToAcc stm acc = do
  convert <- asks distOnSOACSStms
  (converted, _) <- runBinder (convert stm)
  pure $ addStmsToAcc converted acc

-- | Convert a 'SOACS' lambda to the target representation with the
-- environment's 'distOnSOACSLambda' callback.  Statements added to the
-- binder are dropped; only the returned lambda is kept.
soacsLambda :: (MonadFreshNames m, DistLore lore) =>
               Lambda SOACS -> DistNestT lore m (Lambda lore)
soacsLambda lam =
  asks distOnSOACSLambda >>= \convert -> fst <$> runBinder (convert lam)

-- | The distribution monad: a reader over 'DistEnv' stacked on a
-- strict writer of 'DistRes' over the base monad @m@.  The instances
-- below come from the underlying transformer stack via
-- GeneralizedNewtypeDeriving.
newtype DistNestT lore m a =
  DistNestT (ReaderT (DistEnv lore m) (WriterT (DistRes lore) m) a)
  deriving ( Functor
           , Applicative
           , Monad
           , MonadReader (DistEnv lore m)
           , MonadWriter (DistRes lore)
           )

-- | Lift a base-monad action through the writer layer and then the
-- reader layer.
instance MonadTrans (DistNestT lore) where
  lift m = DistNestT (lift (lift m))

-- | Name-source operations are lifted past the reader layer and handled
-- by the 'WriterT' layer's instance (which presumably delegates to @m@).
instance MonadFreshNames m => MonadFreshNames (DistNestT lore m) where
  getNameSource = DistNestT (lift getNameSource)
  putNameSource src = DistNestT (lift (putNameSource src))

-- | The scope is the environment's 'distScope' field.
instance (Monad m, Attributes lore) => HasScope lore (DistNestT lore m) where
  askScope = distScope <$> ask

-- | Run an action with extra bindings, combined in front of the current
-- 'distScope' using '<>'.
instance (Monad m, Attributes lore) => LocalScope lore (DistNestT lore m) where
  localScope types = local extend
    where extend env = env { distScope = types <> distScope env }

-- | Log messages go into the writer output, with no post-statements.
instance Monad m => MonadLogger (DistNestT lore m) where
  addLog msgs = tell (DistRes mempty msgs)

-- | Run a distribution computation in environment @env@.  Accumulated
-- log messages are forwarded to the base monad.  The result is the
-- recorded post-statements followed by statements for whatever targets
-- remain in the final accumulator.
runDistNestT :: (MonadLogger m, DistLore lore) =>
                DistEnv lore m -> DistNestT lore m (DistAcc lore) -> m (Stms lore)
runDistNestT env (DistNestT m) = do
  (acc, res) <- runWriterT (runReaderT m env)
  addLog (accLog res)
  -- There may be a few final targets remaining - these correspond to
  -- arrays that are identity mapped, and must have statements
  -- inserted here.
  let remaining = identityStms $ outerTarget $ distTargets acc
  return $ unPostStms (accPostStms res) <> remaining
  where
    -- The outermost loop: the head of the outer nestings, or the sole
    -- nesting when there are no outer ones.
    (inner_nest, outer_nests) = distNest env
    outermost = nestingLoop $ fromMaybe inner_nest $ listToMaybe outer_nests

    -- Each parameter of the outermost loop, keyed to its input array.
    params_to_arrs =
      [ (paramName p, arr) | (p, arr) <- loopNestingParamsAndArrs outermost ]

    identityStms (rem_pat, rem_res) =
      stmsFromList $ zipWith identityStm (patternValueElements rem_pat) rem_res

    -- A result that is a parameter of the outermost loop becomes a copy
    -- of its input array; any other result is replicated to the width
    -- of the outermost loop.
    identityStm pe se =
      Let (Pattern [] [pe]) (defAux ()) $ BasicOp $
      case se of
        Var v | Just arr <- lookup v params_to_arrs -> Copy arr
        _ -> Replicate (Shape [loopNestingWidth outermost]) se

-- | Record post-statements in the writer output.  They appear at the
-- front of the result of 'runDistNestT'.
addPostStms :: Monad m => PostStms lore -> DistNestT lore m ()
addPostStms ks = tell (DistRes ks mempty)

-- | 'addPostStms' for plain statements.
postStm :: Monad m => Stms lore -> DistNestT lore m ()
postStm = addPostStms . PostStms

-- | Run an action in an environment that accounts for a 'SOACS'
-- statement.  The statement's scope is added to 'distScope', and its
-- pattern names are passed to 'letBindInInnerNesting' on the current
-- nestings.
withStm :: (Monad m, DistLore lore) =>
           Stm SOACS -> DistNestT lore m a -> DistNestT lore m a
withStm stm = local bindStm
  where
    bound = namesFromList $ patternNames $ stmPattern stm
    bindStm env =
      env { distScope = castScope (scopeOf stm) <> distScope env
          , distNest = letBindInInnerNesting bound (distNest env)
          }

-- | Run an action inside a new innermost map nesting built from the
-- given pattern, certificates, width, lambda and input arrays.  The
-- lambda's scope ('scopeOf') is added to 'distScope'.
mapNesting :: (Monad m, DistLore lore) =>
              PatternT Type -> Certificates -> SubExp -> Lambda SOACS -> [VName]
           -> DistNestT lore m a
           -> DistNestT lore m a
mapNesting pat cs w lam arrs = local enter
  where
    -- Each lambda parameter is paired with the array it is drawn from.
    nest = Nesting mempty $ MapNesting pat cs w $ zip (lambdaParams lam) arrs
    enter env =
      env { distNest = pushInnerNesting nest (distNest env)
          , distScope = castScope (scopeOf lam) <> distScope env
          }

-- | Run an action inside the loop nests of a 'KernelNest'.  This
-- replaces the current nestings.  The last loop nesting becomes the
-- innermost nesting; the rest, starting with @outer@, become the outer
-- nestings.  The bindings of every loop nesting are added to
-- 'distScope'.
inNesting :: (Monad m, DistLore lore) =>
             KernelNest -> DistNestT lore m a -> DistNestT lore m a
inNesting (outer, nests) = local enter
  where
    enter env =
      env { distNest = (innermost, others)
          , distScope = foldMap scopeOfLoopNesting (outer : nests) <> distScope env
          }
    toNesting = Nesting mempty
    (innermost, others) =
      case reverse nests of
        [] -> (toNesting outer, [])
        last_nest : rev_rest ->
          (toNesting last_nest, map toNesting (outer : reverse rev_rest))

-- | Does the body have a statement whose expression is an 'Op'?
--
-- In the SOACS representation every 'Op' is a SOAC, for example
-- 'Screma', 'Stream', 'Scatter' or 'Hist'. Only the body's own
-- statements are examined. Statements nested inside loops or branches
-- of those statements are not looked at.
bodyContainsParallelism :: Body SOACS -> Bool
bodyContainsParallelism body =
  not $ null [ () | Let _ _ Op{} <- stmsToList (bodyStms body) ]

-- | Does the lambda's body contain parallelism, in the sense of
-- 'bodyContainsParallelism'?
lambdaContainsParallelism :: Lambda SOACS -> Bool
lambdaContainsParallelism lam = bodyContainsParallelism (lambdaBody lam)

-- | Enable if you want the cool new versioned code.  Beware: may be
-- slower in practice.  Caveat emptor (and you are the emptor).
--
-- The mode is on exactly when the variable
-- @FUTHARK_INCREMENTAL_FLATTENING@ appears in 'unixEnvironment'. Its
-- value is never inspected, so even an empty value or @0@ turns it on.
incrementalFlattening :: Bool
incrementalFlattening =
  any ((== "FUTHARK_INCREMENTAL_FLATTENING") . fst) unixEnvironment

-- | Drop the innermost target from the accumulator's 'distTargets'.
--
-- The popped target (its pattern and result) is discarded. Calls
-- 'error' if 'popInnerTarget' finds nothing to pop.
leavingNesting :: MonadFreshNames m =>
                  DistAcc lore -> DistNestT lore m (DistAcc lore)
leavingNesting acc =
  maybe (error "The kernel targets list is unexpectedly small")
        (\(_, remaining) -> return acc { distTargets = remaining })
        (popInnerTarget (distTargets acc))

-- | Feed the statements of a map body through 'maybeDistributeStm',
-- starting from @orig_acc@, then pass the resulting accumulator to
-- 'distribute'.
--
-- Statements are processed right to left. The tail of the list is
-- handled first, and the current statement is then offered to
-- 'maybeDistributeStm' together with the accumulator produced by the
-- tail.
--
-- A sequential 'Stream' is not offered directly. It is first expanded
-- with 'sequentialStreamWholeArray' and cleaned up with
-- 'copyPropagateInStms'. The resulting statements are certified with
-- the stream's certificates and spliced in front of the remaining
-- statements.
distributeMapBodyStms :: (MonadFreshNames m, DistLore lore)
                      => DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
distributeMapBodyStms orig_acc body_stms =
  distribute =<< go orig_acc (stmsToList body_stms)
  where
    go acc [] =
      return acc

    go acc (Let pat (StmAux cs _) (Op (Stream w (Sequential accs) lam arrs)) : rest) = do
      soacs_scope <- asksScope scopeForSOACs
      (_, expanded) <-
        runBinderT (sequentialStreamWholeArray pat w accs lam arrs) soacs_scope
      (_, simplified) <-
        runReaderT (copyPropagateInStms simpleSOACS soacs_scope expanded) soacs_scope
      let certified = map (certify cs) (stmsToList simplified)
      go acc (certified ++ rest)

    go acc (stm : rest) =
      -- It is important that stm is in scope if 'maybeDistributeStm'
      -- wants to distribute, even if this causes the slightly silly
      -- situation that stm is in scope of itself.
      withStm stm $ do
        acc' <- go acc rest
        maybeDistributeStm stm acc'

-- | Hand an inner map loop to the 'distOnInnerMap' callback stored in
-- the current distribution environment, and return its result.
onInnerMap :: Monad m => MapLoop -> DistAcc lore -> DistNestT lore m (DistAcc lore)
onInnerMap loop acc =
  asks distOnInnerMap >>= \handler -> handler loop acc

-- | Transform SOACS statements with the 'distOnTopLevelStms' callback
-- from the current environment, then hand the produced statements to
-- 'postStm'.
onTopLevelStms :: Monad m => Stms SOACS -> DistNestT lore m ()
onTopLevelStms stms =
  asks distOnTopLevelStms >>= \handler -> postStm =<< handler stms

maybeDistributeStm :: (MonadFreshNames m, DistLore lore) =>
                      Stm SOACS -> DistAcc lore
                   -> DistNestT lore m (DistAcc lore)

maybeDistributeStm :: Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat StmAux (ExpAttr SOACS)
_ (Op (Screma w form arrs))) DistAcc lore
acc
  | Just Lambda
lam <- ScremaForm SOACS -> Maybe Lambda
forall lore. ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC ScremaForm SOACS
form =
  -- Only distribute inside the map if we can distribute everything
  -- following the map.
  DistAcc lore -> DistNestT lore m (Maybe (DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore -> DistNestT lore m (Maybe (DistAcc lore))
distributeIfPossible DistAcc lore
acc DistNestT lore m (Maybe (DistAcc lore))
-> (Maybe (DistAcc lore) -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Maybe (DistAcc lore)
Nothing -> Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc
    Just DistAcc lore
acc' -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore -> DistNestT lore m (DistAcc lore)
distribute (DistAcc lore -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< MapLoop -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
Monad m =>
MapLoop -> DistAcc lore -> DistNestT lore m (DistAcc lore)
onInnerMap (Pattern -> Certificates -> SubExp -> Lambda -> [VName] -> MapLoop
MapLoop Pattern
pat (Stm SOACS -> Certificates
forall lore. Stm lore -> Certificates
stmCerts Stm SOACS
bnd) SubExp
w Lambda
lam [VName]
arrs) DistAcc lore
acc'

maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat StmAux (ExpAttr SOACS)
_ (DoLoop [] [(FParam SOACS, SubExp)]
val form :: LoopForm SOACS
form@ForLoop{} Body SOACS
body)) DistAcc lore
acc
  | [PatElemT Type] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (PatternT Type -> [PatElemT Type]
forall attr. PatternT attr -> [PatElemT attr]
patternContextElements PatternT Type
Pattern
pat), Body SOACS -> Bool
bodyContainsParallelism Body SOACS
body =
  DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
  lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
    -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
      | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ LoopForm SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm SOACS
form Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
        Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
          -- We need to pretend pat_unused was used anyway, by adding
          -- it to the kernel nest.
          Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
 -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
          PostStms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
PostStms lore -> DistNestT lore m ()
addPostStms PostStms lore
kernels
          KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
          Scope SOACS
types <- (Scope lore -> Scope SOACS) -> DistNestT lore m (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope lore -> Scope SOACS
forall lore. SameScope lore SOACS => Scope lore -> Scope SOACS
scopeForSOACs

          Stms SOACS
bnds <- ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
-> Scope SOACS -> DistNestT lore m (Stms SOACS)
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT
                  (KernelNest
-> SeqLoop -> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
forall (m :: * -> *).
(MonadFreshNames m, HasScope SOACS m) =>
KernelNest -> SeqLoop -> m (Stms SOACS)
interchangeLoops KernelNest
nest' ([Int]
-> Pattern
-> [(FParam SOACS, SubExp)]
-> LoopForm SOACS
-> Body SOACS
-> SeqLoop
SeqLoop [Int]
perm Pattern
pat [(FParam SOACS, SubExp)]
val LoopForm SOACS
form Body SOACS
body)) Scope SOACS
types
          Stms SOACS -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
Stms SOACS -> DistNestT lore m ()
onTopLevelStms Stms SOACS
bnds
          DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc lore
acc'
    Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
      Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc

maybeDistributeStm stm :: Stm SOACS
stm@(Let Pattern
pat StmAux (ExpAttr SOACS)
_ (If SubExp
cond Body SOACS
tbranch Body SOACS
fbranch IfAttr (BranchType SOACS)
ret)) DistAcc lore
acc
  | [PatElemT Type] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (PatternT Type -> [PatElemT Type]
forall attr. PatternT attr -> [PatElemT attr]
patternContextElements PatternT Type
Pattern
pat),
    Body SOACS -> Bool
bodyContainsParallelism Body SOACS
tbranch Bool -> Bool -> Bool
|| Body SOACS -> Bool
bodyContainsParallelism Body SOACS
fbranch Bool -> Bool -> Bool
||
    Bool -> Bool
not ((TypeBase ExtShape NoUniqueness -> Bool)
-> [TypeBase ExtShape NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase ExtShape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (IfAttr (TypeBase ExtShape NoUniqueness)
-> [TypeBase ExtShape NoUniqueness]
forall rt. IfAttr rt -> [rt]
ifReturns IfAttr (TypeBase ExtShape NoUniqueness)
IfAttr (BranchType SOACS)
ret)) =
    DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
stm DistNestT
  lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
    -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
      Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
        | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$
          (SubExp -> Names
forall a. FreeIn a => a -> Names
freeIn SubExp
cond Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> IfAttr (TypeBase ExtShape NoUniqueness) -> Names
forall a. FreeIn a => a -> Names
freeIn IfAttr (TypeBase ExtShape NoUniqueness)
IfAttr (BranchType SOACS)
ret) Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
          Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
            -- We need to pretend pat_unused was used anyway, by adding
            -- it to the kernel nest.
            Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
 -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
            KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
            PostStms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
PostStms lore -> DistNestT lore m ()
addPostStms PostStms lore
kernels
            Scope SOACS
types <- (Scope lore -> Scope SOACS) -> DistNestT lore m (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope lore -> Scope SOACS
forall lore. SameScope lore SOACS => Scope lore -> Scope SOACS
scopeForSOACs
            let branch :: Branch
branch = [Int]
-> Pattern
-> SubExp
-> Body SOACS
-> Body SOACS
-> IfAttr (BranchType SOACS)
-> Branch
Branch [Int]
perm Pattern
pat SubExp
cond Body SOACS
tbranch Body SOACS
fbranch IfAttr (BranchType SOACS)
ret
            Stms SOACS
stms <- ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
-> Scope SOACS -> DistNestT lore m (Stms SOACS)
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT (KernelNest
-> Branch -> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
forall (m :: * -> *).
(MonadFreshNames m, HasScope SOACS m) =>
KernelNest -> Branch -> m (Stms SOACS)
interchangeBranch KernelNest
nest' Branch
branch) Scope SOACS
types
            Stms SOACS -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
Stms SOACS -> DistNestT lore m ()
onTopLevelStms Stms SOACS
stms
            DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc lore
acc'
      Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
        Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
stm DistAcc lore
acc

maybeDistributeStm (Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Screma w form arrs))) DistAcc lore
acc
  | Just [Reduce Commutativity
comm Lambda
lam Result
nes] <- ScremaForm SOACS -> Maybe [Reduce SOACS]
forall lore. ScremaForm lore -> Maybe [Reduce lore]
isReduceSOAC ScremaForm SOACS
form,
    Just BinderT SOACS (DistNestT lore m) ()
m <- Pattern
-> SubExp
-> Commutativity
-> Lambda
-> [(SubExp, VName)]
-> Maybe (BinderT SOACS (DistNestT lore m) ())
forall (m :: * -> *).
(MonadBinder m, Lore m ~ SOACS) =>
Pattern
-> SubExp
-> Commutativity
-> Lambda
-> [(SubExp, VName)]
-> Maybe (m ())
irwim Pattern
pat SubExp
w Commutativity
comm Lambda
lam ([(SubExp, VName)] -> Maybe (BinderT SOACS (DistNestT lore m) ()))
-> [(SubExp, VName)] -> Maybe (BinderT SOACS (DistNestT lore m) ())
forall a b. (a -> b) -> a -> b
$ Result -> [VName] -> [(SubExp, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip Result
nes [VName]
arrs = do
      Scope SOACS
types <- (Scope lore -> Scope SOACS) -> DistNestT lore m (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope lore -> Scope SOACS
forall lore. SameScope lore SOACS => Scope lore -> Scope SOACS
scopeForSOACs
      (()
_, Stms SOACS
bnds) <- BinderT SOACS (DistNestT lore m) ()
-> Scope SOACS -> DistNestT lore m ((), Stms SOACS)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT (Certificates
-> BinderT SOACS (DistNestT lore m) ()
-> BinderT SOACS (DistNestT lore m) ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs BinderT SOACS (DistNestT lore m) ()
m) Scope SOACS
types
      DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
distributeMapBodyStms DistAcc lore
acc Stms SOACS
bnds

-- Parallelise segmented scatters.
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Scatter w lam ivs as))) DistAcc lore
acc =
  DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
  lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
    -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
      | Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
        Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
 -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
          KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
          Lambda lore
lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
lam
          PostStms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
PostStms lore -> DistNestT lore m ()
addPostStms PostStms lore
kernels
          Stms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
Stms lore -> DistNestT lore m ()
postStm (Stms lore -> DistNestT lore m ())
-> DistNestT lore m (Stms lore) -> DistNestT lore m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< KernelNest
-> [Int]
-> PatternT Type
-> Certificates
-> SubExp
-> Lambda lore
-> [VName]
-> [(SubExp, Int, VName)]
-> DistNestT lore m (Stms lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
KernelNest
-> [Int]
-> PatternT Type
-> Certificates
-> SubExp
-> Lambda lore
-> [VName]
-> [(SubExp, Int, VName)]
-> DistNestT lore m (Stms lore)
segmentedScatterKernel KernelNest
nest' [Int]
perm PatternT Type
Pattern
pat Certificates
cs SubExp
w Lambda lore
lam' [VName]
ivs [(SubExp, Int, VName)]
as
          DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc lore
acc'
    Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
      Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc

-- Parallelise segmented Hist.
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Hist w ops lam as))) DistAcc lore
acc =
  DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
  lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
    -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
      | Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
        Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
 -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
          Lambda lore
lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
lam
          KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
          PostStms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
PostStms lore -> DistNestT lore m ()
addPostStms PostStms lore
kernels
          Stms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
Stms lore -> DistNestT lore m ()
postStm (Stms lore -> DistNestT lore m ())
-> DistNestT lore m (Stms lore) -> DistNestT lore m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< KernelNest
-> [Int]
-> Certificates
-> SubExp
-> [HistOp SOACS]
-> Lambda lore
-> [VName]
-> DistNestT lore m (Stms lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
KernelNest
-> [Int]
-> Certificates
-> SubExp
-> [HistOp SOACS]
-> Lambda lore
-> [VName]
-> DistNestT lore m (Stms lore)
segmentedHistKernel KernelNest
nest' [Int]
perm Certificates
cs SubExp
w [HistOp SOACS]
ops Lambda lore
lam' [VName]
as
          DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc lore
acc'
    Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
      Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc

-- If the scan can be distributed by itself, we will turn it into a
-- segmented scan.
--
-- If the scan cannot be distributed by itself, it will be
-- sequentialised in the default case for this function.
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Screma w form arrs))) DistAcc lore
acc
  | Just ([Scan SOACS]
scans, Lambda
map_lam) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda)
forall lore. ScremaForm lore -> Maybe ([Scan lore], Lambda lore)
isScanomapSOAC ScremaForm SOACS
form,
    Scan Lambda
lam Result
nes <- [Scan SOACS] -> Scan SOACS
forall lore. Bindable lore => [Scan lore] -> Scan lore
singleScan [Scan SOACS]
scans =
  DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
  lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
    -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
      | Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
          -- We need to pretend pat_unused was used anyway, by adding
          -- it to the kernel nest.
          Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
 -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
          KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
          Lambda lore
map_lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
map_lam
          Lambda lore
lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
lam
          Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
 -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$
            KernelNest
-> [Int]
-> SubExp
-> Lambda lore
-> Lambda lore
-> Result
-> [VName]
-> DistNestT lore m (Maybe (Stms lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
KernelNest
-> [Int]
-> SubExp
-> Lambda lore
-> Lambda lore
-> Result
-> [VName]
-> DistNestT lore m (Maybe (Stms lore))
segmentedScanomapKernel KernelNest
nest' [Int]
perm SubExp
w Lambda lore
lam' Lambda lore
map_lam' Result
nes [VName]
arrs DistNestT lore m (Maybe (Stms lore))
-> (Maybe (Stms lore) -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>=
            Certificates
-> Stm SOACS
-> DistAcc lore
-> PostStms lore
-> DistAcc lore
-> Maybe (Stms lore)
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Certificates
-> Stm SOACS
-> DistAcc lore
-> PostStms lore
-> DistAcc lore
-> Maybe (Stms lore)
-> DistNestT lore m (DistAcc lore)
kernelOrNot Certificates
cs Stm SOACS
bnd DistAcc lore
acc PostStms lore
kernels DistAcc lore
acc'
    Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
      Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc

-- if the reduction can be distributed by itself, we will turn it into a
-- segmented reduce.
--
-- If the reduction cannot be distributed by itself, it will be
-- sequentialised in the default case for this function.
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Screma w form arrs))) DistAcc lore
acc
  | Just ([Reduce SOACS]
reds, Lambda
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda)
forall lore. ScremaForm lore -> Maybe ([Reduce lore], Lambda lore)
isRedomapSOAC ScremaForm SOACS
form,
    Reduce Commutativity
comm Lambda
lam Result
nes <- [Reduce SOACS] -> Reduce SOACS
forall lore. Bindable lore => [Reduce lore] -> Reduce lore
singleReduce [Reduce SOACS]
reds,
    Lambda -> Bool
forall lore. Lambda lore -> Bool
isIdentityLambda Lambda
map_lam Bool -> Bool -> Bool
|| Bool
incrementalFlattening =
  DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
  lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
    -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
      | Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
          -- We need to pretend pat_unused was used anyway, by adding
          -- it to the kernel nest.
          Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
 -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
          KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest

          Lambda lore
lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
lam
          Lambda lore
map_lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
map_lam

          let comm' :: Commutativity
comm' | Lambda -> Bool
forall lore. Lambda lore -> Bool
commutativeLambda Lambda
lam = Commutativity
Commutative
                    | Bool
otherwise             = Commutativity
comm

          KernelNest
-> [Int]
-> SubExp
-> Commutativity
-> Lambda lore
-> Lambda lore
-> Result
-> [VName]
-> DistNestT lore m (Maybe (Stms lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
KernelNest
-> [Int]
-> SubExp
-> Commutativity
-> Lambda lore
-> Lambda lore
-> Result
-> [VName]
-> DistNestT lore m (Maybe (Stms lore))
regularSegmentedRedomapKernel KernelNest
nest' [Int]
perm SubExp
w Commutativity
comm' Lambda lore
lam' Lambda lore
map_lam' Result
nes [VName]
arrs DistNestT lore m (Maybe (Stms lore))
-> (Maybe (Stms lore) -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>=
            Certificates
-> Stm SOACS
-> DistAcc lore
-> PostStms lore
-> DistAcc lore
-> Maybe (Stms lore)
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Certificates
-> Stm SOACS
-> DistAcc lore
-> PostStms lore
-> DistAcc lore
-> Maybe (Stms lore)
-> DistNestT lore m (DistAcc lore)
kernelOrNot Certificates
cs Stm SOACS
bnd DistAcc lore
acc PostStms lore
kernels DistAcc lore
acc'
    Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
      Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc

maybeDistributeStm (Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Screma w form arrs))) DistAcc lore
acc
  | Bool
incrementalFlattening Bool -> Bool -> Bool
|| Maybe ([Reduce SOACS], Lambda) -> Bool
forall a. Maybe a -> Bool
isNothing (ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda)
forall lore. ScremaForm lore -> Maybe ([Reduce lore], Lambda lore)
isRedomapSOAC ScremaForm SOACS
form) = do
  -- This with-loop is too complicated for us to immediately do
  -- anything, so split it up and try again.
  Scope SOACS
scope <- (Scope lore -> Scope SOACS) -> DistNestT lore m (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope lore -> Scope SOACS
forall lore. SameScope lore SOACS => Scope lore -> Scope SOACS
scopeForSOACs
  DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
distributeMapBodyStms DistAcc lore
acc (Stms SOACS -> DistNestT lore m (DistAcc lore))
-> (((), Stms SOACS) -> Stms SOACS)
-> ((), Stms SOACS)
-> DistNestT lore m (DistAcc lore)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm SOACS -> Stm SOACS) -> Stms SOACS -> Stms SOACS
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> Stm SOACS -> Stm SOACS
forall lore. Certificates -> Stm lore -> Stm lore
certify Certificates
cs) (Stms SOACS -> Stms SOACS)
-> (((), Stms SOACS) -> Stms SOACS)
-> ((), Stms SOACS)
-> Stms SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((), Stms SOACS) -> Stms SOACS
forall a b. (a, b) -> b
snd (((), Stms SOACS) -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m ((), Stms SOACS)
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
    BinderT SOACS (DistNestT lore m) ()
-> Scope SOACS -> DistNestT lore m ((), Stms SOACS)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT (Pattern (Lore (BinderT SOACS (DistNestT lore m)))
-> SubExp
-> ScremaForm (Lore (BinderT SOACS (DistNestT lore m)))
-> [VName]
-> BinderT SOACS (DistNestT lore m) ()
forall (m :: * -> *).
(MonadBinder m, Op (Lore m) ~ SOAC (Lore m), Bindable (Lore m)) =>
Pattern (Lore m)
-> SubExp -> ScremaForm (Lore m) -> [VName] -> m ()
dissectScrema Pattern (Lore (BinderT SOACS (DistNestT lore m)))
Pattern
pat SubExp
w ScremaForm (Lore (BinderT SOACS (DistNestT lore m)))
ScremaForm SOACS
form [VName]
arrs) Scope SOACS
scope

maybeDistributeStm (Let Pattern
pat StmAux (ExpAttr SOACS)
aux (BasicOp (Replicate (Shape (SubExp
d:Result
ds)) SubExp
v))) DistAcc lore
acc
  | [Type
t] <- PatternT Type -> [Type]
forall attr. Typed attr => PatternT attr -> [Type]
patternTypes PatternT Type
Pattern
pat = do
      -- XXX: We need a temporary dummy binding to prevent an empty
      -- map body.  The kernel extractor does not like empty map
      -- bodies.
      VName
tmp <- String -> DistNestT lore m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"tmp"
      let rowt :: Type
rowt = Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType Type
t
          newbnd :: Stm SOACS
newbnd = Pattern -> StmAux (ExpAttr SOACS) -> Exp SOACS -> Stm SOACS
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let Pattern
pat StmAux (ExpAttr SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ Op SOACS -> Exp SOACS
forall lore. Op lore -> ExpT lore
Op (Op SOACS -> Exp SOACS) -> Op SOACS -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> ScremaForm SOACS -> [VName] -> SOAC SOACS
forall lore. SubExp -> ScremaForm lore -> [VName] -> SOAC lore
Screma SubExp
d (Lambda -> ScremaForm SOACS
forall lore. Lambda lore -> ScremaForm lore
mapSOAC Lambda
lam) []
          tmpbnd :: Stm SOACS
tmpbnd = Pattern -> StmAux (ExpAttr SOACS) -> Exp SOACS -> Stm SOACS
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let ([PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [] [VName -> Type -> PatElemT Type
forall attr. VName -> attr -> PatElemT attr
PatElem VName
tmp Type
rowt]) StmAux (ExpAttr SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$
                   BasicOp -> Exp SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate (Result -> Shape
forall d. [d] -> ShapeBase d
Shape Result
ds) SubExp
v
          lam :: Lambda
lam = Lambda :: forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda { lambdaReturnType :: [Type]
lambdaReturnType = [Type
rowt]
                       , lambdaParams :: [LParam SOACS]
lambdaParams = []
                       , lambdaBody :: Body SOACS
lambdaBody = Stms SOACS -> Result -> Body SOACS
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody (Stm SOACS -> Stms SOACS
forall lore. Stm lore -> Stms lore
oneStm Stm SOACS
tmpbnd) [VName -> SubExp
Var VName
tmp]
                       }
      Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
maybeDistributeStm Stm SOACS
newbnd DistAcc lore
acc

maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
_ StmAux (ExpAttr SOACS)
aux (BasicOp Copy{})) DistAcc lore
acc =
  DistAcc lore
-> Stm SOACS
-> (KernelNest
    -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> (KernelNest
    -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
distributeSingleUnaryStm DistAcc lore
acc Stm SOACS
bnd ((KernelNest
  -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
 -> DistNestT lore m (DistAcc lore))
-> (KernelNest
    -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ \KernelNest
_ PatternT Type
outerpat VName
arr ->
  Stms lore -> DistNestT lore m (Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms lore -> DistNestT lore m (Stms lore))
-> Stms lore -> DistNestT lore m (Stms lore)
forall a b. (a -> b) -> a -> b
$ Stm lore -> Stms lore
forall lore. Stm lore -> Stms lore
oneStm (Stm lore -> Stms lore) -> Stm lore -> Stms lore
forall a b. (a -> b) -> a -> b
$ Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let PatternT Type
Pattern lore
outerpat StmAux (ExpAttr lore)
StmAux (ExpAttr SOACS)
aux (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
arr

-- Opaques are applied to the full array, because otherwise they can
-- drastically inhibit parallelisation in some cases.
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let (Pattern [] [PatElemT (LetAttr SOACS)
pe]) StmAux (ExpAttr SOACS)
aux (BasicOp Opaque{})) DistAcc lore
acc
  | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> Type -> Bool
forall a b. (a -> b) -> a -> b
$ PatElemT Type -> Type
forall t. Typed t => t -> Type
typeOf PatElemT Type
PatElemT (LetAttr SOACS)
pe =
      DistAcc lore
-> Stm SOACS
-> (KernelNest
    -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> (KernelNest
    -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
distributeSingleUnaryStm DistAcc lore
acc Stm SOACS
bnd ((KernelNest
  -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
 -> DistNestT lore m (DistAcc lore))
-> (KernelNest
    -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ \KernelNest
_ PatternT Type
outerpat VName
arr ->
      Stms lore -> DistNestT lore m (Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms lore -> DistNestT lore m (Stms lore))
-> Stms lore -> DistNestT lore m (Stms lore)
forall a b. (a -> b) -> a -> b
$ Stm lore -> Stms lore
forall lore. Stm lore -> Stms lore
oneStm (Stm lore -> Stms lore) -> Stm lore -> Stms lore
forall a b. (a -> b) -> a -> b
$ Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let PatternT Type
Pattern lore
outerpat StmAux (ExpAttr lore)
StmAux (ExpAttr SOACS)
aux (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
arr

maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
_ StmAux (ExpAttr SOACS)
aux (BasicOp (Rearrange [Int]
perm VName
_))) DistAcc lore
acc =
  DistAcc lore
-> Stm SOACS
-> (KernelNest
    -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> (KernelNest
    -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
distributeSingleUnaryStm DistAcc lore
acc Stm SOACS
bnd ((KernelNest
  -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
 -> DistNestT lore m (DistAcc lore))
-> (KernelNest
    -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ \KernelNest
nest PatternT Type
outerpat VName
arr -> do
    let r :: Int
r = [LoopNesting] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (KernelNest -> [LoopNesting]
forall a b. (a, b) -> b
snd KernelNest
nest) Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1
        perm' :: [Int]
perm' = [Int
0..Int
rInt -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] [Int] -> [Int] -> [Int]
forall a. [a] -> [a] -> [a]
++ (Int -> Int) -> [Int] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> Int -> Int
forall a. Num a => a -> a -> a
+Int
r) [Int]
perm
    -- We need to add a copy, because the original map nest
    -- will have produced an array without aliases, and so must we.
    VName
arr' <- String -> DistNestT lore m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> DistNestT lore m VName)
-> String -> DistNestT lore m VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString VName
arr
    Type
arr_t <- VName -> DistNestT lore m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
    Stms lore -> DistNestT lore m (Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms lore -> DistNestT lore m (Stms lore))
-> Stms lore -> DistNestT lore m (Stms lore)
forall a b. (a -> b) -> a -> b
$ [Stm lore] -> Stms lore
forall lore. [Stm lore] -> Stms lore
stmsFromList
      [Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let ([PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [] [VName -> Type -> PatElemT Type
forall attr. VName -> attr -> PatElemT attr
PatElem VName
arr' Type
arr_t]) StmAux (ExpAttr lore)
StmAux (ExpAttr SOACS)
aux (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
arr,
       Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let PatternT Type
Pattern lore
outerpat StmAux (ExpAttr lore)
StmAux (ExpAttr SOACS)
aux (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp
Rearrange [Int]
perm' VName
arr']

maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
_ StmAux (ExpAttr SOACS)
aux (BasicOp (Reshape ShapeChange SubExp
reshape VName
_))) DistAcc lore
acc =
  DistAcc lore
-> Stm SOACS
-> (KernelNest
    -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> (KernelNest
    -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
distributeSingleUnaryStm DistAcc lore
acc Stm SOACS
bnd ((KernelNest
  -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
 -> DistNestT lore m (DistAcc lore))
-> (KernelNest
    -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ \KernelNest
nest PatternT Type
outerpat VName
arr -> do
    let reshape' :: ShapeChange SubExp
reshape' = (SubExp -> DimChange SubExp) -> Result -> ShapeChange SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimChange SubExp
forall d. d -> DimChange d
DimNew (KernelNest -> Result
kernelNestWidths KernelNest
nest) ShapeChange SubExp -> ShapeChange SubExp -> ShapeChange SubExp
forall a. [a] -> [a] -> [a]
++
                   (SubExp -> DimChange SubExp) -> Result -> ShapeChange SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimChange SubExp
forall d. d -> DimChange d
DimNew (ShapeChange SubExp -> Result
forall d. ShapeChange d -> [d]
newDims ShapeChange SubExp
reshape)
    Stms lore -> DistNestT lore m (Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms lore -> DistNestT lore m (Stms lore))
-> Stms lore -> DistNestT lore m (Stms lore)
forall a b. (a -> b) -> a -> b
$ Stm lore -> Stms lore
forall lore. Stm lore -> Stms lore
oneStm (Stm lore -> Stms lore) -> Stm lore -> Stms lore
forall a b. (a -> b) -> a -> b
$ Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let PatternT Type
Pattern lore
outerpat StmAux (ExpAttr lore)
StmAux (ExpAttr SOACS)
aux (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> VName -> BasicOp
Reshape ShapeChange SubExp
reshape' VName
arr

maybeDistributeStm stm :: Stm SOACS
stm@(Let Pattern
_ StmAux (ExpAttr SOACS)
aux (BasicOp (Rotate Result
rots VName
_))) DistAcc lore
acc =
  DistAcc lore
-> Stm SOACS
-> (KernelNest
    -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> (KernelNest
    -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
distributeSingleUnaryStm DistAcc lore
acc Stm SOACS
stm ((KernelNest
  -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
 -> DistNestT lore m (DistAcc lore))
-> (KernelNest
    -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ \KernelNest
nest PatternT Type
outerpat VName
arr -> do
    let rots' :: Result
rots' = (SubExp -> SubExp) -> Result -> Result
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> SubExp -> SubExp
forall a b. a -> b -> a
const (SubExp -> SubExp -> SubExp) -> SubExp -> SubExp -> SubExp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
0) (KernelNest -> Result
kernelNestWidths KernelNest
nest) Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
rots
    Stms lore -> DistNestT lore m (Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms lore -> DistNestT lore m (Stms lore))
-> Stms lore -> DistNestT lore m (Stms lore)
forall a b. (a -> b) -> a -> b
$ Stm lore -> Stms lore
forall lore. Stm lore -> Stms lore
oneStm (Stm lore -> Stms lore) -> Stm lore -> Stms lore
forall a b. (a -> b) -> a -> b
$ Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let PatternT Type
Pattern lore
outerpat StmAux (ExpAttr lore)
StmAux (ExpAttr SOACS)
aux (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ Result -> VName -> BasicOp
Rotate Result
rots' VName
arr

maybeDistributeStm stm :: Stm SOACS
stm@(Let Pattern
pat StmAux (ExpAttr SOACS)
aux (BasicOp (Update VName
arr Slice SubExp
slice (Var VName
v)))) DistAcc lore
acc
  | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Result -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (Result -> Bool) -> Result -> Bool
forall a b. (a -> b) -> a -> b
$ Slice SubExp -> Result
forall d. Slice d -> [d]
sliceDims Slice SubExp
slice =
    DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
stm DistNestT
  lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
    -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
      | Result
res Result -> Result -> Bool
forall a. Eq a => a -> a -> Bool
== (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var (PatternT Type -> [VName]
forall attr. PatternT attr -> [VName]
patternNames (PatternT Type -> [VName]) -> PatternT Type -> [VName]
forall a b. (a -> b) -> a -> b
$ Stm SOACS -> Pattern
forall lore. Stm lore -> Pattern lore
stmPattern Stm SOACS
stm),
        Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res -> do
          PostStms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
PostStms lore -> DistNestT lore m ()
addPostStms PostStms lore
kernels
          Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
 -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
            KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
            Stms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
Stms lore -> DistNestT lore m ()
postStm (Stms lore -> DistNestT lore m ())
-> DistNestT lore m (Stms lore) -> DistNestT lore m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
              KernelNest
-> [Int]
-> Certificates
-> VName
-> Slice SubExp
-> VName
-> DistNestT lore m (Stms lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
KernelNest
-> [Int]
-> Certificates
-> VName
-> Slice SubExp
-> VName
-> DistNestT lore m (Stms lore)
segmentedUpdateKernel KernelNest
nest' [Int]
perm (StmAux () -> Certificates
forall attr. StmAux attr -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpAttr SOACS)
aux) VName
arr Slice SubExp
slice VName
v
            DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc lore
acc'

    Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ -> Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
stm DistAcc lore
acc

-- XXX?  This rule is present to avoid the case where an in-place
-- update is distributed as its own kernel, as this would mean thread
-- then writes the entire array that it updated.  This is problematic
-- because the in-place updates is O(1), but writing the array is
-- O(n).  It is OK if the in-place update is preceded, followed, or
-- nested inside a sequential loop or similar, because that will
-- probably be O(n) by itself.  As a hack, we only distribute if there
-- does not appear to be a loop following.  The better solution is to
-- depend on memory block merging for this optimisation, but it is not
-- ready yet.
maybeDistributeStm (Let Pattern
pat StmAux (ExpAttr SOACS)
aux (BasicOp (Update VName
arr [DimFix SubExp
i] SubExp
v))) DistAcc lore
acc
  | [Type
t] <- PatternT Type -> [Type]
forall attr. Typed attr => PatternT attr -> [Type]
patternTypes PatternT Type
Pattern
pat,
    Type -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank Type
t Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1,
    Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (Stm lore -> Bool) -> Stms lore -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (Exp lore -> Bool
forall lore. ExpT lore -> Bool
amortises (Exp lore -> Bool) -> (Stm lore -> Exp lore) -> Stm lore -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm lore -> Exp lore
forall lore. Stm lore -> Exp lore
stmExp) (Stms lore -> Bool) -> Stms lore -> Bool
forall a b. (a -> b) -> a -> b
$ DistAcc lore -> Stms lore
forall lore. DistAcc lore -> Stms lore
distStms DistAcc lore
acc = do
      let w :: SubExp
w = Int -> Type -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 Type
t
          et :: Type
et = Int -> Type -> Type
forall shape u.
ArrayShape shape =>
Int -> TypeBase shape u -> TypeBase shape u
stripArray Int
1 Type
t
          lam :: Lambda
lam = Lambda :: forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda { lambdaParams :: [LParam SOACS]
lambdaParams = []
                       , lambdaReturnType :: [Type]
lambdaReturnType = [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int32, Type
et]
                       , lambdaBody :: Body SOACS
lambdaBody = Stms SOACS -> Result -> Body SOACS
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody Stms SOACS
forall a. Monoid a => a
mempty [SubExp
i, SubExp
v] }
      Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
maybeDistributeStm (Pattern -> StmAux (ExpAttr SOACS) -> Exp SOACS -> Stm SOACS
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let Pattern
pat StmAux (ExpAttr SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ Op SOACS -> Exp SOACS
forall lore. Op lore -> ExpT lore
Op (Op SOACS -> Exp SOACS) -> Op SOACS -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> Lambda -> [VName] -> [(SubExp, Int, VName)] -> SOAC SOACS
forall lore.
SubExp
-> Lambda lore -> [VName] -> [(SubExp, Int, VName)] -> SOAC lore
Scatter (IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
1) Lambda
lam [] [(SubExp
w, Int
1, VName
arr)]) DistAcc lore
acc
  where amortises :: ExpT lore -> Bool
amortises DoLoop{} = Bool
True
        amortises Op{} = Bool
True
        amortises ExpT lore
_ = Bool
False

maybeDistributeStm stm :: Stm SOACS
stm@(Let Pattern
_ StmAux (ExpAttr SOACS)
aux (BasicOp (Concat Int
d VName
x [VName]
xs SubExp
w))) DistAcc lore
acc =
  DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
     lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
stm DistNestT
  lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
    -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostStms lore
kernels, Result
_, KernelNest
nest, DistAcc lore
acc') ->
      Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
 -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$
      KernelNest -> DistNestT lore m (Maybe (Stms lore))
segmentedConcat KernelNest
nest DistNestT lore m (Maybe (Stms lore))
-> (Maybe (Stms lore) -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>=
      Certificates
-> Stm SOACS
-> DistAcc lore
-> PostStms lore
-> DistAcc lore
-> Maybe (Stms lore)
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Certificates
-> Stm SOACS
-> DistAcc lore
-> PostStms lore
-> DistAcc lore
-> Maybe (Stms lore)
-> DistNestT lore m (DistAcc lore)
kernelOrNot (StmAux () -> Certificates
forall attr. StmAux attr -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpAttr SOACS)
aux) Stm SOACS
stm DistAcc lore
acc PostStms lore
kernels DistAcc lore
acc'
    Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
      Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
stm DistAcc lore
acc

  where segmentedConcat :: KernelNest -> DistNestT lore m (Maybe (Stms lore))
segmentedConcat KernelNest
nest =
          KernelNest
-> [Int]
-> SubExp
-> Names
-> Names
-> Result
-> [VName]
-> (PatternT Type
    -> [(VName, SubExp)]
    -> [KernelInput]
    -> Result
    -> [VName]
    -> [VName]
    -> BinderT lore m ())
-> DistNestT lore m (Maybe (Stms lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
KernelNest
-> [Int]
-> SubExp
-> Names
-> Names
-> Result
-> [VName]
-> (PatternT Type
    -> [(VName, SubExp)]
    -> [KernelInput]
    -> Result
    -> [VName]
    -> [VName]
    -> BinderT lore m ())
-> DistNestT lore m (Maybe (Stms lore))
isSegmentedOp KernelNest
nest [Int
0] SubExp
w Names
forall a. Monoid a => a
mempty Names
forall a. Monoid a => a
mempty [] (VName
xVName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
:[VName]
xs) ((PatternT Type
  -> [(VName, SubExp)]
  -> [KernelInput]
  -> Result
  -> [VName]
  -> [VName]
  -> BinderT lore m ())
 -> DistNestT lore m (Maybe (Stms lore)))
-> (PatternT Type
    -> [(VName, SubExp)]
    -> [KernelInput]
    -> Result
    -> [VName]
    -> [VName]
    -> BinderT lore m ())
-> DistNestT lore m (Maybe (Stms lore))
forall a b. (a -> b) -> a -> b
$
          \PatternT Type
pat [(VName, SubExp)]
_ [KernelInput]
_ Result
_ (VName
x':[VName]
xs') [VName]
_ ->
            let d' :: Int
d' = Int
d Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [LoopNesting] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (KernelNest -> [LoopNesting]
forall a b. (a, b) -> b
snd KernelNest
nest) Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1
            in Stm (Lore (BinderT lore m)) -> BinderT lore m ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm (Stm (Lore (BinderT lore m)) -> BinderT lore m ())
-> Stm (Lore (BinderT lore m)) -> BinderT lore m ()
forall a b. (a -> b) -> a -> b
$ Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let PatternT Type
Pattern lore
pat StmAux (ExpAttr lore)
StmAux (ExpAttr SOACS)
aux (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ Int -> VName -> [VName] -> SubExp -> BasicOp
Concat Int
d' VName
x' [VName]
xs' SubExp
w

maybeDistributeStm Stm SOACS
bnd DistAcc lore
acc =
  Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc

-- | Distribute a single statement that consumes exactly one mapped
-- array, handing the (possibly expanded) array to a continuation.
--
-- The statement is first distributed with 'distributeSingleStm'.  The
-- fast path is taken only when all of the following hold:
--
--   * the distribution result is exactly the statement's own pattern
--     names;
--   * the outermost nesting level maps over a single parameter/array
--     pair;
--   * the only name bound by the kernel nest that the statement uses is
--     that outermost parameter.
--
-- In that case the pending kernel statements are added as post
-- statements.  Any dimensions of the input array that are "skipped" by
-- the nest are materialised with a 'Repeat'.  The continuation @f@ is
-- then called with the nest, the outermost pattern and the expanded
-- array, and both the 'Repeat' statements and @f@'s statements are
-- posted.
--
-- Otherwise the statement is simply added to the (original)
-- accumulator with 'addStmToAcc'.
distributeSingleUnaryStm :: (MonadFreshNames m, DistLore lore) =>
                            DistAcc lore -> Stm SOACS
                         -> (KernelNest -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
                         -> DistNestT lore m (DistAcc lore)
distributeSingleUnaryStm acc bnd f =
  distributeSingleStm acc bnd >>= \case
    Just (kernels, res, nest, acc')
      | res == map Var (patternNames $ stmPattern bnd),
        (outer, inners) <- nest,
        [(arr_p, arr)] <- loopNestingParamsAndArrs outer,
        boundInKernelNest nest `namesIntersection` freeIn bnd
          == oneName (paramName arr_p) -> do
          addPostStms kernels
          let outerpat = loopNestingPattern $ fst nest
          localScope (typeEnvFromDistAcc acc') $ do
            (arr', pre_stms) <- repeatMissing arr (outer : inners)
            f_stms <- inScopeOf pre_stms $ f nest outerpat arr'
            postStm $ pre_stms <> f_stms
            return acc'
    _ -> addStmToAcc bnd acc
  where
    -- For an imperfectly mapped array, repeat the missing dimensions
    -- to make it look like it was in fact perfectly mapped.  Returns
    -- the array unchanged (and no statements) when nothing is missing.
    repeatMissing arr inners = do
      arr_t <- lookupType arr
      let shapes = determineRepeats arr arr_t inners
      if all (== Shape []) shapes
        then return (arr, mempty)
        else do
          let (outer_shapes, inner_shape) = repeatShapes shapes arr_t
              arr_t' = repeatDims outer_shapes inner_shape arr_t
          arr' <- newVName $ baseString arr
          return (arr', oneStm $ Let (Pattern [] [PatElem arr' arr_t']) (defAux ()) $
                        BasicOp $ Repeat outer_shapes inner_shape arr)

    -- For each nesting level that maps over (a row of) the array,
    -- collect the widths of the levels skipped before it.  Once the
    -- array is no longer an input, the remaining widths form the
    -- outermost repeat and the array's own dimensions get empty shapes.
    determineRepeats arr arr_t nests
      | (skipped, arr_nest : nests') <- break (hasInput arr) nests,
        [(arr_p, _)] <- loopNestingParamsAndArrs arr_nest =
          Shape (map loopNestingWidth skipped) :
          determineRepeats (paramName arr_p) (rowType arr_t) nests'
      | otherwise =
          Shape (map loopNestingWidth nests) : replicate (arrayRank arr_t) (Shape [])

    -- Does this nesting level map over exactly one array, namely 'arr'?
    hasInput arr nest =
      case loopNestingParamsAndArrs nest of
        [(_, arr')] -> arr' == arr
        _ -> False


-- | Distribute the accumulated statements if possible.
--
-- When 'distributeIfPossible' succeeds, its updated accumulator is
-- returned.  Otherwise the accumulator is returned unchanged.
distribute :: (MonadFreshNames m, DistLore lore) => DistAcc lore -> DistNestT lore m (DistAcc lore)
distribute acc =
  fromMaybe acc <$> distributeIfPossible acc

-- | Adapt the environment's segment-level generator ('distSegLevel'),
-- which runs in @BinderT lore m@, so that it can be used from a binder
-- layered over 'DistNestT'.
--
-- The generator is run in the base monad under the current binder
-- scope.  The statements it produces are re-added to the outer binder,
-- and the chosen level is returned.
mkSegLevel :: (MonadFreshNames m, DistLore lore) => DistNestT lore m (MkSegLevel lore (DistNestT lore m))
mkSegLevel = do
  mk_lvl <- asks distSegLevel
  return $ \w desc r -> do
    scope <- askScope
    -- Two lifts: out of the BinderT, then out of DistNestT into m.
    (lvl, stms) <- lift . lift $ runBinderT (mk_lvl w desc r) scope
    addStms stms
    return lvl

-- | Try to distribute the accumulated statements into a kernel at the
-- current nesting.
--
-- On success, the resulting kernel statements are posted with
-- 'postStm'.  The function then returns an accumulator holding the new
-- targets and no pending statements.  On failure nothing is posted and
-- 'Nothing' is returned.
distributeIfPossible :: (MonadFreshNames m, DistLore lore) => DistAcc lore -> DistNestT lore m (Maybe (DistAcc lore))
distributeIfPossible acc = do
  nest <- asks distNest
  mk_lvl <- mkSegLevel
  distributed <- tryDistribute mk_lvl nest (distTargets acc) (distStms acc)
  -- 'traverse' over Maybe: only the success case has effects.
  traverse emitKernel distributed
  where
    emitKernel (targets, kernel) = do
      postStm kernel
      return DistAcc { distTargets = targets
                     , distStms = mempty
                     }

-- | Distribute the accumulated statements, then try to distribute one
-- more statement ('bnd') against the resulting targets.
--
-- Unlike 'distributeIfPossible', the kernel statements produced for the
-- accumulated statements are not posted.  They are returned as
-- 'PostStms' so the caller decides what to do with them.
--
-- On success the result contains four parts:
--
--   * those pending kernel statements;
--   * the result computed for 'bnd';
--   * the kernel nest 'bnd' ended up in;
--   * a fresh accumulator with the updated targets and no pending
--     statements.
--
-- If either distribution step fails, 'Nothing' is returned and nothing
-- is posted.
distributeSingleStm :: (MonadFreshNames m, DistLore lore) =>
                       DistAcc lore -> Stm SOACS
                    -> DistNestT lore m (Maybe (PostStms lore,
                                                Result,
                                                KernelNest,
                                                 DistAcc lore))
distributeSingleStm acc bnd = do
  nest <- asks distNest
  mk_lvl <- mkSegLevel
  runMaybeT $ do
    (targets, distributed_bnds) <-
      MaybeT $ tryDistribute mk_lvl nest (distTargets acc) (distStms acc)
    (res, targets', new_kernel_nest) <-
      MaybeT $ tryDistributeStm nest targets bnd
    return ( PostStms distributed_bnds
           , res
           , new_kernel_nest
           , DistAcc { distTargets = targets'
                     , distStms = mempty
                     }
           )

segmentedScatterKernel :: (MonadFreshNames m, DistLore lore) =>
                          KernelNest
                       -> [Int]
                       -> PatternT Type
                       -> Certificates
                       -> SubExp
                       -> Lambda lore
                       -> [VName] -> [(SubExp,Int,VName)]
                       -> DistNestT lore m (Stms lore)
segmentedScatterKernel :: KernelNest
-> [Int]
-> PatternT Type
-> Certificates
-> SubExp
-> Lambda lore
-> [VName]
-> [(SubExp, Int, VName)]
-> DistNestT lore m (Stms lore)
segmentedScatterKernel KernelNest
nest [Int]
perm PatternT Type
scatter_pat Certificates
cs SubExp
scatter_w Lambda lore
lam [VName]
ivs [(SubExp, Int, VName)]
dests = do
  -- We replicate some of the checking done by 'isSegmentedOp', but
  -- things are different because a scatter is not a reduction or
  -- scan.
  --
  -- First, pretend that the scatter is also part of the nesting.  The
  -- KernelNest we produce here is technically not sensible, but it's
  -- good enough for flatKernel to work.
  let nest' :: KernelNest
nest' = (PatternT Type, Result) -> LoopNesting -> KernelNest -> KernelNest
pushInnerKernelNesting (PatternT Type
scatter_pat, BodyT lore -> Result
forall lore. BodyT lore -> Result
bodyResult (BodyT lore -> Result) -> BodyT lore -> Result
forall a b. (a -> b) -> a -> b
$ Lambda lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda lore
lam)
              (PatternT Type
-> Certificates -> SubExp -> [(Param Type, VName)] -> LoopNesting
MapNesting PatternT Type
scatter_pat Certificates
cs SubExp
scatter_w ([(Param Type, VName)] -> LoopNesting)
-> [(Param Type, VName)] -> LoopNesting
forall a b. (a -> b) -> a -> b
$ [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda lore -> [LParam lore]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda lore
lam) [VName]
ivs) KernelNest
nest
  ([(VName, SubExp)]
ispace, [KernelInput]
kernel_inps) <- KernelNest -> DistNestT lore m ([(VName, SubExp)], [KernelInput])
forall (m :: * -> *).
MonadFreshNames m =>
KernelNest -> m ([(VName, SubExp)], [KernelInput])
flatKernel KernelNest
nest'

  let (Result
as_ws, [Int]
as_ns, [VName]
as) = [(SubExp, Int, VName)] -> (Result, [Int], [VName])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(SubExp, Int, VName)]
dests

  -- The input/output arrays ('as') _must_ correspond to some kernel
  -- input, or else the original nested scatter would have been
  -- ill-typed.  Find them.
  [KernelInput]
as_inps <- (VName -> DistNestT lore m KernelInput)
-> [VName] -> DistNestT lore m [KernelInput]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([KernelInput] -> VName -> DistNestT lore m KernelInput
forall (m :: * -> *) (t :: * -> *).
(Monad m, Foldable t) =>
t KernelInput -> VName -> m KernelInput
findInput [KernelInput]
kernel_inps) [VName]
as

  Result
-> String
-> ThreadRecommendation
-> BinderT lore (DistNestT lore m) (SegOpLevel lore)
mk_lvl <- DistNestT
  lore
  m
  (Result
   -> String
   -> ThreadRecommendation
   -> BinderT lore (DistNestT lore m) (SegOpLevel lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistNestT lore m (MkSegLevel lore (DistNestT lore m))
mkSegLevel

  let rts :: [Type]
rts = ([Type] -> [Type]) -> [[Type]] -> [Type]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take Int
1) ([[Type]] -> [Type]) -> [[Type]] -> [Type]
forall a b. (a -> b) -> a -> b
$ [Int] -> [Type] -> [[Type]]
forall a. [Int] -> [a] -> [[a]]
chunks [Int]
as_ns ([Type] -> [[Type]]) -> [Type] -> [[Type]]
forall a b. (a -> b) -> a -> b
$
            Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
drop ([Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Int]
as_ns) ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda lore -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda lore
lam
      (Result
is,Result
vs) = Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitAt ([Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Int]
as_ns) (Result -> (Result, Result)) -> Result -> (Result, Result)
forall a b. (a -> b) -> a -> b
$ BodyT lore -> Result
forall lore. BodyT lore -> Result
bodyResult (BodyT lore -> Result) -> BodyT lore -> Result
forall a b. (a -> b) -> a -> b
$ Lambda lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda lore
lam

  -- Maybe add certificates to the indices.
  (Result
is', Stms lore
k_body_stms) <- Binder lore Result -> DistNestT lore m (Result, Stms lore)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore a -> m (a, Stms lore)
runBinder (Binder lore Result -> DistNestT lore m (Result, Stms lore))
-> Binder lore Result -> DistNestT lore m (Result, Stms lore)
forall a b. (a -> b) -> a -> b
$ do
    Stms (Lore (BinderT lore (State VNameSource)))
-> BinderT lore (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms (Lore (BinderT lore (State VNameSource)))
 -> BinderT lore (State VNameSource) ())
-> Stms (Lore (BinderT lore (State VNameSource)))
-> BinderT lore (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ BodyT lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT lore -> Stms lore) -> BodyT lore -> Stms lore
forall a b. (a -> b) -> a -> b
$ Lambda lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda lore
lam
    Result
-> (SubExp -> BinderT lore (State VNameSource) SubExp)
-> Binder lore Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM Result
is ((SubExp -> BinderT lore (State VNameSource) SubExp)
 -> Binder lore Result)
-> (SubExp -> BinderT lore (State VNameSource) SubExp)
-> Binder lore Result
forall a b. (a -> b) -> a -> b
$ \SubExp
i ->
      if Certificates
cs Certificates -> Certificates -> Bool
forall a. Eq a => a -> a -> Bool
== Certificates
forall a. Monoid a => a
mempty
      then SubExp -> BinderT lore (State VNameSource) SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return SubExp
i
      else Certificates
-> BinderT lore (State VNameSource) SubExp
-> BinderT lore (State VNameSource) SubExp
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs (BinderT lore (State VNameSource) SubExp
 -> BinderT lore (State VNameSource) SubExp)
-> BinderT lore (State VNameSource) SubExp
-> BinderT lore (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ String
-> Exp (Lore (BinderT lore (State VNameSource)))
-> BinderT lore (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"scatter_i" (Exp (Lore (BinderT lore (State VNameSource)))
 -> BinderT lore (State VNameSource) SubExp)
-> Exp (Lore (BinderT lore (State VNameSource)))
-> BinderT lore (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
i

  let k_body :: KernelBody lore
k_body = BodyAttr lore -> Stms lore -> [KernelResult] -> KernelBody lore
forall lore.
BodyAttr lore -> Stms lore -> [KernelResult] -> KernelBody lore
KernelBody () Stms lore
k_body_stms ([KernelResult] -> KernelBody lore)
-> [KernelResult] -> KernelBody lore
forall a b. (a -> b) -> a -> b
$
               ((SubExp, KernelInput, [(SubExp, SubExp)]) -> KernelResult)
-> [(SubExp, KernelInput, [(SubExp, SubExp)])] -> [KernelResult]
forall a b. (a -> b) -> [a] -> [b]
map ([(VName, SubExp)]
-> (SubExp, KernelInput, [(SubExp, SubExp)]) -> KernelResult
inPlaceReturn [(VName, SubExp)]
ispace) ([(SubExp, KernelInput, [(SubExp, SubExp)])] -> [KernelResult])
-> [(SubExp, KernelInput, [(SubExp, SubExp)])] -> [KernelResult]
forall a b. (a -> b) -> a -> b
$
               Result
-> [KernelInput]
-> [[(SubExp, SubExp)]]
-> [(SubExp, KernelInput, [(SubExp, SubExp)])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Result
as_ws [KernelInput]
as_inps ([[(SubExp, SubExp)]]
 -> [(SubExp, KernelInput, [(SubExp, SubExp)])])
-> [[(SubExp, SubExp)]]
-> [(SubExp, KernelInput, [(SubExp, SubExp)])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [(SubExp, SubExp)] -> [[(SubExp, SubExp)]]
forall a. [Int] -> [a] -> [[a]]
chunks [Int]
as_ns ([(SubExp, SubExp)] -> [[(SubExp, SubExp)]])
-> [(SubExp, SubExp)] -> [[(SubExp, SubExp)]]
forall a b. (a -> b) -> a -> b
$ Result -> Result -> [(SubExp, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip Result
is' Result
vs

  (SegOp (SegOpLevel lore) lore
k, Stms lore
k_bnds) <- (Result
 -> String
 -> ThreadRecommendation
 -> BinderT lore (DistNestT lore m) (SegOpLevel lore))
-> [(VName, SubExp)]
-> [KernelInput]
-> [Type]
-> KernelBody lore
-> DistNestT lore m (SegOp (SegOpLevel lore) lore, Stms lore)
forall lore (m :: * -> *).
(DistLore lore, HasScope lore m, MonadFreshNames m) =>
MkSegLevel lore m
-> [(VName, SubExp)]
-> [KernelInput]
-> [Type]
-> KernelBody lore
-> m (SegOp (SegOpLevel lore) lore, Stms lore)
mapKernel Result
-> String
-> ThreadRecommendation
-> BinderT lore (DistNestT lore m) (SegOpLevel lore)
mk_lvl [(VName, SubExp)]
ispace [KernelInput]
kernel_inps [Type]
rts KernelBody lore
k_body

  (Stm lore -> DistNestT lore m (Stm lore))
-> Stms lore -> DistNestT lore m (Stms lore)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse Stm lore -> DistNestT lore m (Stm lore)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Stm lore -> m (Stm lore)
renameStm (Stms lore -> DistNestT lore m (Stms lore))
-> (BinderT lore (State VNameSource) ()
    -> DistNestT lore m (Stms lore))
-> BinderT lore (State VNameSource) ()
-> DistNestT lore m (Stms lore)
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< BinderT lore (State VNameSource) () -> DistNestT lore m (Stms lore)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore a -> m (Stms lore)
runBinder_ (BinderT lore (State VNameSource) ()
 -> DistNestT lore m (Stms lore))
-> BinderT lore (State VNameSource) ()
-> DistNestT lore m (Stms lore)
forall a b. (a -> b) -> a -> b
$ do
    Stms (Lore (BinderT lore (State VNameSource)))
-> BinderT lore (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms Stms lore
Stms (Lore (BinderT lore (State VNameSource)))
k_bnds

    let pat :: PatternT Type
pat = [PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [] ([PatElemT Type] -> PatternT Type)
-> [PatElemT Type] -> PatternT Type
forall a b. (a -> b) -> a -> b
$ [Int] -> [PatElemT Type] -> [PatElemT Type]
forall a. [Int] -> [a] -> [a]
rearrangeShape [Int]
perm ([PatElemT Type] -> [PatElemT Type])
-> [PatElemT Type] -> [PatElemT Type]
forall a b. (a -> b) -> a -> b
$
              PatternT Type -> [PatElemT Type]
forall attr. PatternT attr -> [PatElemT attr]
patternValueElements (PatternT Type -> [PatElemT Type])
-> PatternT Type -> [PatElemT Type]
forall a b. (a -> b) -> a -> b
$ LoopNesting -> PatternT Type
loopNestingPattern (LoopNesting -> PatternT Type) -> LoopNesting -> PatternT Type
forall a b. (a -> b) -> a -> b
$ KernelNest -> LoopNesting
forall a b. (a, b) -> a
fst KernelNest
nest

    Pattern (Lore (BinderT lore (State VNameSource)))
-> Exp (Lore (BinderT lore (State VNameSource)))
-> BinderT lore (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ PatternT Type
Pattern (Lore (BinderT lore (State VNameSource)))
pat (Exp (Lore (BinderT lore (State VNameSource)))
 -> BinderT lore (State VNameSource) ())
-> Exp (Lore (BinderT lore (State VNameSource)))
-> BinderT lore (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Op lore -> ExpT lore
forall lore. Op lore -> ExpT lore
Op (Op lore -> ExpT lore) -> Op lore -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SegOp (SegOpLevel lore) lore -> Op lore
forall lore.
HasSegOp lore =>
SegOp (SegOpLevel lore) lore -> Op lore
segOp SegOp (SegOpLevel lore) lore
k
  where findInput :: t KernelInput -> VName -> m KernelInput
findInput t KernelInput
kernel_inps VName
a =
          m KernelInput
-> (KernelInput -> m KernelInput)
-> Maybe KernelInput
-> m KernelInput
forall b a. b -> (a -> b) -> Maybe a -> b
maybe m KernelInput
forall a. a
bad KernelInput -> m KernelInput
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe KernelInput -> m KernelInput)
-> Maybe KernelInput -> m KernelInput
forall a b. (a -> b) -> a -> b
$ (KernelInput -> Bool) -> t KernelInput -> Maybe KernelInput
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
==VName
a) (VName -> Bool) -> (KernelInput -> VName) -> KernelInput -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelInput -> VName
kernelInputName) t KernelInput
kernel_inps
        bad :: a
bad = String -> a
forall a. HasCallStack => String -> a
error String
"Ill-typed nested scatter encountered."

        inPlaceReturn :: [(VName, SubExp)]
-> (SubExp, KernelInput, [(SubExp, SubExp)]) -> KernelResult
inPlaceReturn [(VName, SubExp)]
ispace (SubExp
aw, KernelInput
inp, [(SubExp, SubExp)]
is_vs) =
          Result -> VName -> [(Result, SubExp)] -> KernelResult
WriteReturns (Result -> Result
forall a. [a] -> [a]
init Result
wsResult -> Result -> Result
forall a. [a] -> [a] -> [a]
++[SubExp
aw]) (KernelInput -> VName
kernelInputArray KernelInput
inp)
          [ ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> [VName]
forall a. [a] -> [a]
init [VName]
gtids)Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++[SubExp
i], SubExp
v) | (SubExp
i,SubExp
v) <- [(SubExp, SubExp)]
is_vs ]
          where ([VName]
gtids,Result
ws) = [(VName, SubExp)] -> ([VName], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(VName, SubExp)]
ispace

-- | Distribute an in-place 'Update' of @arr@ (at @slice@, with value @v@)
-- that sits inside the kernel nest @nest@ as a single flat 'SegOp'.
--
-- The iteration space is the flattened nest extended with one fresh
-- thread index per dimension of the slice.  Each thread:
--
--   * reads the element of @v@ at its slice position (under @cs@);
--   * translates that slice position into an index of the full array
--     via 'fixSlice' (all index arithmetic is done at 'int32');
--   * writes the element through 'WriteReturns' into the kernel input
--     array that stands in for @arr@.
--
-- The kernel is bound to the value elements of the nest's pattern,
-- reordered by @perm@, and all produced statements are renamed.
-- Fails with an 'error' if @arr@ is not one of the nest's kernel inputs.
segmentedUpdateKernel :: (MonadFreshNames m, DistLore lore) =>
                         KernelNest
                      -> [Int]
                      -> Certificates
                      -> VName
                      -> Slice SubExp
                      -> VName
                      -> DistNestT lore m (Stms lore)
segmentedUpdateKernel nest perm cs arr slice v = do
  (outer_space, kinputs) <- flatKernel nest
  let inner_dims = sliceDims slice
  inner_gtids <- replicateM (length inner_dims) $ newVName "gtid_slice"

  let seg_space = outer_space ++ zip inner_gtids inner_dims
      asPrimExp = primExpFromSubExp int32
      inner_pes = map (asPrimExp . Var) inner_gtids

  ((val_t, write_res), body_stms) <- runBinder $ do
    -- Fetch the element of v that this thread is responsible for.
    val <- certifying cs $ letSubExp "v" $ BasicOp $
           Index v [DimFix (Var gtid) | gtid <- inner_gtids]
    -- Work out where that element lands in the full array.
    inner_is <- forM (fixSlice (map (fmap asPrimExp) slice) inner_pes) $ \pe ->
      toExp pe >>= letSubExp "index"
    let target = case find ((== arr) . kernelInputName) kinputs of
                   Just inp -> kernelInputArray inp
                   Nothing -> error "incorrectly typed Update"
        dest_is = map (Var . fst) outer_space ++ inner_is
    target_t <- lookupType target
    t <- subExpType val
    return (t, WriteReturns (arrayDims target_t) target [(dest_is, val)])

  mk_lvl <- mkSegLevel
  (seg_op, pre_stms) <- mapKernel mk_lvl seg_space kinputs [val_t] $
                        KernelBody () body_stms [write_res]

  let result_pat = Pattern [] $ rearrangeShape perm $
                   patternValueElements $ loopNestingPattern $ fst nest
  stms <- runBinder_ $ do
    addStms pre_stms
    letBind_ result_pat $ Op $ segOp seg_op
  traverse renameStm stms

-- | Distribute a 'Hist' that sits inside the kernel nest @nest@ as one
-- flat segmented histogram, built by 'histKernel'.
--
-- Every destination array named by the histogram operators must be one
-- of the nest's kernel inputs; each is replaced by that input's array.
-- If one is not found, an 'error' ("Ill-typed nested Hist encountered.")
-- is raised.  The result is bound to the value elements of the nest's
-- pattern, reordered by @perm@.  SOACS operator lambdas are converted
-- with the environment's 'distOnSOACSLambda'.
segmentedHistKernel :: (MonadFreshNames m, DistLore lore) =>
                       KernelNest
                    -> [Int]
                    -> Certificates
                    -> SubExp
                    -> [SOACS.HistOp SOACS]
                    -> Lambda lore
                    -> [VName]
                    -> DistNestT lore m (Stms lore)
segmentedHistKernel nest perm cs hist_w ops lam arrs = do
  -- Part of the checking done by 'isSegmentedOp' is repeated here; a
  -- Hist is neither a reduction nor a scan, so it needs its own path.
  (seg_space, kinputs) <- flatKernel nest
  let result_pat = Pattern [] $ rearrangeShape perm $
                   patternValueElements $ loopNestingPattern $ fst nest

  -- Each destination must be a kernel input; otherwise the original
  -- nested Hist would have been ill-typed.
  ops_flat <- forM ops $ \(SOACS.HistOp num_bins rf dests nes op) -> do
    dests' <- mapM (fmap kernelInputArray . findInput kinputs) dests
    return $ SOACS.HistOp num_bins rf dests' nes op

  mk_lvl <- asks distSegLevel
  scope <- askScope
  soacsLambda <- asks distOnSOACSLambda
  let convertLambda l = fst <$> runBinder (soacsLambda l)
      build = do
        -- Do not launch more threads than needed for a histogram, since
        -- that can force an unnecessary reduction of subhistograms.
        lvl <- mk_lvl (hist_w : map snd seg_space) "seghist" $
               NoRecommendation SegNoVirt
        histKernel convertLambda lvl result_pat seg_space kinputs
          cs hist_w ops_flat lam arrs >>= addStms
  lift $ runBinderT_ build scope
  where findInput kernel_inps a =
          maybe bad return $ find ((==a) . kernelInputName) kernel_inps
        bad = error "Ill-typed nested Hist encountered."

histKernel :: (MonadBinder m, DistLore (Lore m)) =>
              (Lambda SOACS -> m (Lambda (Lore m)))
           -> SegOpLevel (Lore m)
           -> PatternT Type -> [(VName, SubExp)] -> [KernelInput]
           -> Certificates -> SubExp -> [SOACS.HistOp SOACS]
           -> Lambda (Lore m) -> [VName]
           -> m (Stms (Lore m))
histKernel :: (Lambda -> m (Lambda (Lore m)))
-> SegOpLevel (Lore m)
-> PatternT Type
-> [(VName, SubExp)]
-> [KernelInput]
-> Certificates
-> SubExp
-> [HistOp SOACS]
-> Lambda (Lore m)
-> [VName]
-> m (Stms (Lore m))
histKernel Lambda -> m (Lambda (Lore m))
onLambda SegOpLevel (Lore m)
lvl PatternT Type
orig_pat [(VName, SubExp)]
ispace [KernelInput]
inputs Certificates
cs SubExp
hist_w [HistOp SOACS]
ops Lambda (Lore m)
lam [VName]
arrs = BinderT (Lore m) m () -> m (Stms (Lore m))
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
BinderT lore m a -> m (Stms lore)
runBinderT'_ (BinderT (Lore m) m () -> m (Stms (Lore m)))
-> BinderT (Lore m) m () -> m (Stms (Lore m))
forall a b. (a -> b) -> a -> b
$ do
  [HistOp (Lore m)]
ops' <- [HistOp SOACS]
-> (HistOp SOACS -> BinderT (Lore m) m (HistOp (Lore m)))
-> BinderT (Lore m) m [HistOp (Lore m)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp SOACS]
ops ((HistOp SOACS -> BinderT (Lore m) m (HistOp (Lore m)))
 -> BinderT (Lore m) m [HistOp (Lore m)])
-> (HistOp SOACS -> BinderT (Lore m) m (HistOp (Lore m)))
-> BinderT (Lore m) m [HistOp (Lore m)]
forall a b. (a -> b) -> a -> b
$ \(SOACS.HistOp SubExp
num_bins SubExp
rf [VName]
dests Result
nes Lambda
op) -> do
    (Lambda
op', Result
nes', Shape
shape) <- Lambda -> Result -> BinderT (Lore m) m (Lambda, Result, Shape)
forall (m :: * -> *) lore.
(MonadBinder m, Lore m ~ lore) =>
Lambda -> Result -> m (Lambda, Result, Shape)
determineReduceOp Lambda
op Result
nes
    Lambda (Lore m)
op'' <- m (Lambda (Lore m)) -> BinderT (Lore m) m (Lambda (Lore m))
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m (Lambda (Lore m)) -> BinderT (Lore m) m (Lambda (Lore m)))
-> m (Lambda (Lore m)) -> BinderT (Lore m) m (Lambda (Lore m))
forall a b. (a -> b) -> a -> b
$ Lambda -> m (Lambda (Lore m))
onLambda Lambda
op'
    HistOp (Lore m) -> BinderT (Lore m) m (HistOp (Lore m))
forall (m :: * -> *) a. Monad m => a -> m a
return (HistOp (Lore m) -> BinderT (Lore m) m (HistOp (Lore m)))
-> HistOp (Lore m) -> BinderT (Lore m) m (HistOp (Lore m))
forall a b. (a -> b) -> a -> b
$ SubExp
-> SubExp
-> [VName]
-> Result
-> Shape
-> Lambda (Lore m)
-> HistOp (Lore m)
forall lore.
SubExp
-> SubExp
-> [VName]
-> Result
-> Shape
-> Lambda lore
-> HistOp lore
HistOp SubExp
num_bins SubExp
rf [VName]
dests Result
nes' Shape
shape Lambda (Lore m)
op''

  let isDest :: VName -> Bool
isDest = (VName -> [VName] -> Bool) -> [VName] -> VName -> Bool
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
elem ([VName] -> VName -> Bool) -> [VName] -> VName -> Bool
forall a b. (a -> b) -> a -> b
$ (HistOp (Lore m) -> [VName]) -> [HistOp (Lore m)] -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap HistOp (Lore m) -> [VName]
forall lore. HistOp lore -> [VName]
histDest [HistOp (Lore m)]
ops'
      inputs' :: [KernelInput]
inputs' = (KernelInput -> Bool) -> [KernelInput] -> [KernelInput]
forall a. (a -> Bool) -> [a] -> [a]
filter (Bool -> Bool
not (Bool -> Bool) -> (KernelInput -> Bool) -> KernelInput -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Bool
isDest (VName -> Bool) -> (KernelInput -> VName) -> KernelInput -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelInput -> VName
kernelInputArray) [KernelInput]
inputs

  Certificates -> BinderT (Lore m) m () -> BinderT (Lore m) m ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs (BinderT (Lore m) m () -> BinderT (Lore m) m ())
-> BinderT (Lore m) m () -> BinderT (Lore m) m ()
forall a b. (a -> b) -> a -> b
$
    Stms (Lore m) -> BinderT (Lore m) m ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms (Lore m) -> BinderT (Lore m) m ())
-> BinderT (Lore m) m (Stms (Lore m)) -> BinderT (Lore m) m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stm (Lore m) -> BinderT (Lore m) m (Stm (Lore m)))
-> Stms (Lore m) -> BinderT (Lore m) m (Stms (Lore m))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse Stm (Lore m) -> BinderT (Lore m) m (Stm (Lore m))
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Stm lore -> m (Stm lore)
renameStm (Stms (Lore m) -> BinderT (Lore m) m (Stms (Lore m)))
-> BinderT (Lore m) m (Stms (Lore m))
-> BinderT (Lore m) m (Stms (Lore m))
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
    SegOpLevel (Lore m)
-> Pattern (Lore m)
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp (Lore m)]
-> Lambda (Lore m)
-> [VName]
-> BinderT (Lore m) m (Stms (Lore m))
forall lore (m :: * -> *).
(DistLore lore, MonadFreshNames m, HasScope lore m) =>
SegOpLevel lore
-> Pattern lore
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp lore]
-> Lambda lore
-> [VName]
-> m (Stms lore)
segHist SegOpLevel (Lore m)
lvl PatternT Type
Pattern (Lore m)
orig_pat SubExp
hist_w [(VName, SubExp)]
ispace [KernelInput]
inputs' [HistOp (Lore m)]
ops' Lambda (Lore m)
lam [VName]
arrs

-- | Split a histogram/reduction operator into an inner operator plus
-- the shape of any vectorised dimensions, as recognised by
-- 'isVectorMap'.
--
-- When every neutral element is a variable, each one is indexed at
-- position zero along the peeled-off dimensions, and the result is
-- bound under the name @hist_ne@.  The result is the inner lambda,
-- these new neutral elements, and the peeled-off shape.
--
-- If any neutral element is not a variable, the lambda and neutral
-- elements are returned unchanged together with an empty shape.
determineReduceOp :: (MonadBinder m, Lore m ~ lore) =>
                     Lambda SOACS -> [SubExp]
                  -> m (Lambda SOACS, [SubExp], Shape)
determineReduceOp lam nes =
  -- FIXME? We are assuming that the accumulator is a replicate, and
  -- we fish out its value in a gross way.
  maybe (return (lam, nes, mempty)) vectorised $ mapM subExpVar nes
  where
    vectorised ne_vs = do
      let (shape, inner_lam) = isVectorMap lam
          -- One fixed index of 0 for every peeled-off dimension.
          zeroes = replicate (shapeRank shape) $ DimFix $ intConst Int32 0
          indexAtZero ne_v = do
            ne_t <- lookupType ne_v
            letSubExp "hist_ne" $ BasicOp $ Index ne_v $ fullSlice ne_t zeroes
      nes' <- mapM indexAtZero ne_vs
      return (inner_lam, nes', shape)

-- | Peel nested maps off a lambda.
--
-- A lambda qualifies when all of the following hold:
--
--   * its body is a single 'Screma' statement with no context
--     pattern elements;
--   * the body's result is exactly that statement's pattern;
--   * the 'Screma' is a plain map;
--   * the map's input arrays are exactly the lambda's parameters.
--
-- In that case the map width is prepended to whatever shape the
-- inner map lambda yields recursively.  Any other lambda gives an
-- empty shape and is returned as-is.
isVectorMap :: Lambda SOACS -> (Shape, Lambda SOACS)
isVectorMap lam =
  case stmsToList $ bodyStms lam_body of
    [Let (Pattern [] pes) _ (Op (Screma w form arrs))]
      | bodyResult lam_body == map (Var . patElemName) pes,
        Just map_lam <- isMapSOAC form,
        arrs == map paramName (lambdaParams lam) ->
          let (inner_shape, inner_lam) = isVectorMap map_lam
          in (Shape [w] <> inner_shape, inner_lam)
    _ -> (mempty, lam)
  where
    lam_body = lambdaBody lam

-- | Try to express a scanomap nested inside @nest@ as a single
-- segmented scan ('segScan'), with @lam@ as a noncommutative scan
-- operator and @map_lam@ as the map function.
--
-- The nest is checked and prepared by 'isSegmentedOp'; 'Nothing' is
-- returned when it rejects the nest.
--
-- The segment level comes from 'distSegLevel' with name @"segscan"@.
-- The statements produced by 'segScan' are renamed before being
-- added.
segmentedScanomapKernel :: (MonadFreshNames m, DistLore lore) =>
                           KernelNest
                        -> [Int]
                        -> SubExp
                        -> Lambda lore -> Lambda lore
                        -> [SubExp] -> [VName]
                        -> DistNestT lore m (Maybe (Stms lore))
segmentedScanomapKernel nest perm segment_size lam map_lam nes arrs = do
  mk_lvl <- asks distSegLevel
  -- The two array lists offered by isSegmentedOp are ignored: segScan
  -- is given the original @arrs@ together with the index space and
  -- kernel inputs.
  let mkScan pat ispace inps nes' _ _ = do
        let scan_op = SegBinOp Noncommutative lam nes' mempty
            lvl_dims = segment_size : map snd ispace
        lvl <- mk_lvl lvl_dims "segscan" $ NoRecommendation SegNoVirt
        scan_stms <- segScan lvl pat segment_size [scan_op] map_lam arrs ispace inps
        addStms =<< traverse renameStm scan_stms
  isSegmentedOp nest perm segment_size (freeIn lam) (freeIn map_lam) nes arrs mkScan

-- | Try to express a redomap nested inside @nest@ as a single
-- segmented reduction ('segRed'), with @lam@ as the reduction
-- operator (commutativity given by @comm@) and @map_lam@ as the map
-- function.
--
-- The nest is checked and prepared by 'isSegmentedOp'; 'Nothing' is
-- returned when it rejects the nest.
--
-- The segment level comes from 'distSegLevel' with name @"segred"@.
-- The statements produced by 'segRed' are renamed before being
-- added.
regularSegmentedRedomapKernel :: (MonadFreshNames m, DistLore lore) =>
                                 KernelNest
                              -> [Int]
                              -> SubExp -> Commutativity
                              -> Lambda lore -> Lambda lore
                              -> [SubExp] -> [VName]
                              -> DistNestT lore m (Maybe (Stms lore))
regularSegmentedRedomapKernel nest perm segment_size comm lam map_lam nes arrs = do
  mk_lvl <- asks distSegLevel
  -- The two array lists offered by isSegmentedOp are ignored: segRed
  -- is given the original @arrs@ together with the index space and
  -- kernel inputs.
  let mkRed pat ispace inps nes' _ _ = do
        let red_op = SegBinOp comm lam nes' mempty
            lvl_dims = segment_size : map snd ispace
        lvl <- mk_lvl lvl_dims "segred" $ NoRecommendation SegNoVirt
        red_stms <- segRed lvl pat segment_size [red_op] map_lam arrs ispace inps
        addStms =<< traverse renameStm red_stms
  isSegmentedOp nest perm segment_size (freeIn lam) (freeIn map_lam) nes arrs mkRed

-- | Check whether an operator nested inside @nest@ can become a flat
-- segmented operation whose innermost segment has size
-- @segment_size@.  If so, prepare its arguments and hand them to the
-- callback @m@.
--
-- The checks run in 'MaybeT'.  The result is 'Nothing' when:
--
--   * a name in @free_in_op@ is bound by the nest;
--   * a neutral element is a variable bound by the nest;
--   * an array in @arrs@ cannot be obtained from the flattened
--     nest's kernel inputs and is not free in the nest.
--
-- Otherwise the callback runs in a fresh 'BinderT' over the current
-- scope, and the statements it binds are returned.  The callback
-- receives, in order:
--
--   1. the value elements of the pattern of @fst nest@, rearranged
--      by @perm@;
--   2. the flat index space from 'flatKernel';
--   3. the kernel inputs;
--   4. the neutral elements;
--   5. the arrays as prepared for the index space;
--   6. those arrays with their outer @2 + length (snd nest)@
--      dimensions reshaped into one dimension of size
--      @segment_size * product (map snd ispace)@.
--
-- The fifth argument (free names of the fold/map lambda) is
-- currently ignored.
isSegmentedOp :: (MonadFreshNames m, DistLore lore) =>
                 KernelNest
              -> [Int]
              -> SubExp
              -> Names -> Names
              -> [SubExp] -> [VName]
              -> (PatternT Type
                  -> [(VName, SubExp)]
                  -> [KernelInput]
                  -> [SubExp] -> [VName]  -> [VName]
                  -> BinderT lore m ())
              -> DistNestT lore m (Maybe (Stms lore))
isSegmentedOp nest perm segment_size free_in_op _free_in_fold_op nes arrs m = runMaybeT $ do
  let bound_by_nest = boundInKernelNest nest
      isNestBound v = v `nameIn` bound_by_nest

  (ispace, kernel_inps) <- flatKernel nest

  -- The operator itself may not refer to anything bound by the nest.
  -- Such operators are rejected, not rewritten.
  when (free_in_op `namesIntersect` bound_by_nest) $
    fail "Non-fold lambda uses nest-bound parameters."

  let indices = map fst ispace
      nest_widths = map snd ispace

      -- Neutral elements must not be bound by the nest.
      prepareNe ne
        | Var v <- ne, isNestBound v = fail "Neutral element bound in nest"
        | otherwise = return ne

      -- Decide how each input array is obtained; the returned binder
      -- action is only run later, inside runBinderT_.
      prepareArr arr =
        case find ((== arr) . kernelInputName) kernel_inps of
          Just inp
            | kernelInputIndices inp == map Var indices ->
                -- Indexed by exactly the index-space variables: the
                -- underlying array can be used directly.
                return $ return $ kernelInputArray inp
            | not $ isNestBound $ kernelInputArray inp ->
                return $ replicateMissing ispace inp
          Nothing
            | not $ isNestBound arr ->
                -- This input is free inside the loop nesting, so it
                -- is replicated across the whole index space.
                return $ letExp (baseString arr ++ "_repd") $
                  BasicOp $ Replicate (Shape nest_widths) $ Var arr
          -- Also reached by a kernel input for which neither guard
          -- above holds.
          _ -> fail "Input not free or outermost."

  nes' <- mapM prepareNe nes
  mk_arrs <- mapM prepareArr arrs
  scope <- lift askScope

  lift $ lift $ flip runBinderT_ scope $ do
    -- Every input is flattened so that its outer size is
    -- segment_size * (product of the nest widths).
    total_num_elements <-
      letSubExp "total_num_elements" =<<
      foldBinOp (Mul Int32 OverflowUndef) segment_size nest_widths

    let flatten arr = do
          arr_shape <- arrayShape <$> lookupType arr
          -- CHECKME: is the length the right thing here?  We want to
          -- reproduce the parameter type.
          let num_outer = 2 + length (snd nest)
              flat_shape = reshapeOuter [DimNew total_num_elements] num_outer arr_shape
          letExp (baseString arr ++ "_flat") $ BasicOp $ Reshape flat_shape arr

    nested_arrs <- sequence mk_arrs
    arrs' <- mapM flatten nested_arrs

    let pat = Pattern [] $ rearrangeShape perm $
              patternValueElements $ loopNestingPattern $ fst nest

    m pat ispace kernel_inps nes' nested_arrs arrs'

  where
    -- Build an array for a kernel input whose indices differ from the
    -- full index space.  The input array is 'Repeat'ed using the
    -- shapes computed by determineRepeats.
    replicateMissing ispace inp = do
      let inp_arr = kernelInputArray inp
      inp_t <- lookupType inp_arr
      let repeats = determineRepeats ispace $ kernelInputIndices inp
          (outer_shapes, inner_shape) = repeatShapes repeats inp_t
      letExp "repeated" $ BasicOp $ Repeat outer_shapes inner_shape inp_arr

    -- For each kernel index in turn, the widths of the index-space
    -- dimensions skipped before reaching it.  The last entry holds the
    -- widths of the dimensions left over after the final index.
    determineRepeats ispace [] = [Shape $ map snd ispace]
    determineRepeats ispace (i:is) =
      let (skipped, rest) = span ((/= i) . Var . fst) ispace
      in Shape (map snd skipped) : determineRepeats (drop 1 rest) is

-- | Check whether the value elements of a pattern can be matched up with a
-- result list.
--
-- A value element of @pat@ is considered /unused/ when its name is not free
-- in @res@. Each unused element is appended to @res@ as a 'Var', and the
-- names of all value elements are then checked with 'isPermutationOf'
-- against this expanded result.
--
-- On success, returns the index list produced by 'isPermutationOf' together
-- with the unused pattern elements (in pattern order). Returns 'Nothing'
-- when 'isPermutationOf' does. The exact direction of the index mapping is
-- whatever 'isPermutationOf' defines; confirm it there.
permutationAndMissing :: PatternT Type -> [SubExp] -> Maybe ([Int], [PatElemT Type])
permutationAndMissing pat res =
  (\perm -> (perm, unused)) <$> (pe_vars `isPermutationOf` res_expanded)
  where
    value_elems = patternValueElements pat
    res_free = freeIn res
    -- Pattern elements that the result list does not mention at all.
    unused = filter (not . (`nameIn` res_free) . patElemName) value_elems
    res_expanded = res ++ map (Var . patElemName) unused
    pe_vars = map (Var . patElemName) value_elems

-- | Add extra pattern elements to every nesting level of a kernel nest.
--
-- At each level, every element of @pes@ is copied under a fresh name (built
-- from its base name with 'newVName'). The copy's type is turned into an
-- array, using 'arrayOfShape', whose dimensions are the width of that level
-- followed by the widths of all levels nested inside it. The copies are
-- appended after the level's existing pattern elements. The rebuilt pattern
-- is formed as @Pattern []@ over all of the old 'patternElements'.
expandKernelNest :: MonadFreshNames m =>
                    [PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest pes (outer_nest, inner_nests) = do
  outer_nest' <- expandWith outer_nest (outer_width : inner_widths)
  -- 'tails' pairs each inner level with its own width plus all deeper ones.
  inner_nests' <- sequence $ zipWith expandWith inner_nests (tails inner_widths)
  pure (outer_nest', inner_nests')
  where
    outer_width = loopNestingWidth outer_nest
    inner_widths = map loopNestingWidth inner_nests

    -- Extend one nesting level's pattern with expanded copies of 'pes'.
    expandWith nest dims = do
      extra <- traverse (expandPatElemWith dims) pes
      let old_elems = patternElements $ loopNestingPattern nest
      pure nest { loopNestingPattern = Pattern [] (old_elems <> extra) }

    -- Rename a pattern element freshly and wrap its type in the given dims.
    expandPatElemWith dims pe = do
      fresh <- newVName (baseString (patElemName pe))
      pure pe { patElemName = fresh
              , patElemAttr = patElemType pe `arrayOfShape` Shape dims
              }

-- | Choose between keeping a statement sequential and using a distributed
-- replacement for it.
--
-- * 'Nothing': the original SOACS statement @stm@, certified with @cs@, is
--   added to the accumulator @acc@ via 'addStmToAcc'.
--
-- * @'Just' stms@: @kernels@ is registered with 'addPostStms', then every
--   statement in @stms@ is certified with @cs@ and emitted via 'postStm'.
--   The alternative accumulator @acc'@ is returned unchanged.
kernelOrNot :: (MonadFreshNames m, DistLore lore) =>
               Certificates -> Stm SOACS -> DistAcc lore
            -> PostStms lore -> DistAcc lore -> Maybe (Stms lore)
            -> DistNestT lore m (DistAcc lore)
kernelOrNot cs stm acc kernels acc' = \case
  Nothing ->
    addStmToAcc (certify cs stm) acc
  Just stms -> do
    addPostStms kernels
    postStm (certify cs <$> stms)
    pure acc'

-- | Distribute a 'MapLoop'.
--
-- The lambda's result is pushed as a new inner target onto the targets of
-- @acc@, and the statement list of that inner accumulator starts empty. The
-- lambda's body statements are distributed with 'distributeMapBodyStms' and
-- then 'distribute' inside a 'mapNesting' for this map. The accumulator
-- returned from the nesting is passed through 'leavingNesting' and
-- distributed once more.
distributeMap :: (MonadFreshNames m, DistLore lore) =>
                 MapLoop -> DistAcc lore
              -> DistNestT lore m (DistAcc lore)
distributeMap (MapLoop pat cs w lam arrs) acc = do
  nested <- mapNesting pat cs w lam arrs $
    distribute =<< distributeMapBodyStms inner_acc (bodyStms lam_body)
  left <- leavingNesting nested
  distribute left
  where
    lam_body = lambdaBody lam

    -- Accumulator used while inside the map: same targets as 'acc' plus
    -- this map's (pattern, lambda result) pair, and no statements yet.
    inner_acc =
      DistAcc { distTargets =
                  pushInnerTarget (pat, bodyResult lam_body) (distTargets acc)
              , distStms = mempty
              }