{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns -Wno-incomplete-patterns -Wno-incomplete-uni-patterns -Wno-incomplete-record-updates #-}

module Futhark.Pass.ExtractKernels.DistributeNests
  ( MapLoop (..),
    mapLoopStm,
    bodyContainsParallelism,
    lambdaContainsParallelism,
    determineReduceOp,
    histKernel,
    DistEnv (..),
    DistAcc (..),
    runDistNestT,
    DistNestT,
    liftInner,
    distributeMap,
    distribute,
    distributeSingleStm,
    distributeMapBodyStms,
    addStmsToAcc,
    addStmToAcc,
    permutationAndMissing,
    addPostStms,
    postStm,
    inNesting,
  )
where

import Control.Arrow (first)
import Control.Monad.Identity
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.Trans.Maybe
import Control.Monad.Writer.Strict
import Data.Function ((&))
import Data.List (find, partition, tails)
import qualified Data.Map as M
import Data.Maybe
import Futhark.IR
import Futhark.IR.SOACS (SOACS)
import qualified Futhark.IR.SOACS as SOACS
import Futhark.IR.SOACS.SOAC hiding (HistOp, histDest)
import Futhark.IR.SOACS.Simplify (simpleSOACS, simplifyStms)
import Futhark.IR.SegOp
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.Interchange
import Futhark.Tools
import Futhark.Transform.CopyPropagate
import qualified Futhark.Transform.FirstOrderTransform as FOT
import Futhark.Transform.Rename
import Futhark.Util
import Futhark.Util.Log

-- | Reinterpret the scope of a representation that shares its scope
-- information with 'SOACS' as a 'SOACS' scope.  This is 'castScope'
-- pinned to a fixed target representation.
scopeForSOACs :: SameScope rep SOACS => Scope rep -> Scope SOACS
scopeForSOACs = castScope

-- | The components of a map over SOACS code, kept apart so they can be
-- inspected individually: result pattern, statement auxiliary
-- information, width, mapped lambda, and input arrays (in that
-- order).  'mapLoopStm' reassembles them into a single statement.
data MapLoop = MapLoop SOACS.Pat (StmAux ()) SubExp SOACS.Lambda [VName]

-- | Turn a 'MapLoop' back into the SOACS statement it describes: a
-- 'Screma' of width @w@ over @arrs@ whose form is @'mapSOAC' lam@,
-- bound to @pat@ and carrying the auxiliary information @aux@.
mapLoopStm :: MapLoop -> Stm SOACS
mapLoopStm (MapLoop pat aux w lam arrs) =
  Let pat aux $ Op $ Screma w arrs $ mapSOAC lam

-- | The read-only environment carried by 'DistNestT'.
data DistEnv rep m = DistEnv
  { -- | The loop nestings currently in effect.
    distNest :: Nestings,
    -- | The scope that 'askScope' reports inside 'DistNestT'.
    distScope :: Scope rep,
    -- | Callback translating a sequence of SOACS statements into
    -- statements of the target representation (judging by the name,
    -- intended for statements outside any nesting).
    distOnTopLevelStms :: Stms SOACS -> DistNestT rep m (Stms rep),
    -- | Callback invoked on a 'MapLoop' together with the current
    -- accumulator; it produces the updated accumulator.
    distOnInnerMap ::
      MapLoop ->
      DistAcc rep ->
      DistNestT rep m (DistAcc rep),
    -- | Callback translating a single SOACS statement into statements
    -- of the target representation.
    distOnSOACSStms :: Stm SOACS -> Builder rep (Stms rep),
    -- | Callback translating a SOACS lambda into a lambda of the
    -- target representation.
    distOnSOACSLambda :: Lambda SOACS -> Builder rep (Lambda rep),
    -- | How to construct segment levels for the target representation.
    distSegLevel :: MkSegLevel rep m
  }

-- | The accumulator threaded through distribution.
data DistAcc rep = DistAcc
  { -- | The targets that remain to be produced.
    distTargets :: Targets,
    -- | The statements accumulated so far.  New statements are put in
    -- front of the existing ones (see 'addStmsToAcc').
    distStms :: Stms rep
  }

-- | The output written by 'DistNestT' computations.
data DistRes rep = DistRes
  { -- | Statements emitted through 'addPostStms' / 'postStm'.
    accPostStms :: PostStms rep,
    -- | Log messages emitted through 'addLog'.
    accLog :: Log
  }

-- | Combines field by field, using the 'Semigroup' instance of each
-- component.
instance Semigroup (DistRes rep) where
  DistRes ks1 log1 <> DistRes ks2 log2 =
    DistRes (ks1 <> ks2) (log1 <> log2)

-- | The empty result: no post-statements and an empty log.
instance Monoid (DistRes rep) where
  mempty = DistRes mempty mempty

-- | A wrapper around 'Stms' whose 'Semigroup' instance appends in
-- flipped order.
newtype PostStms rep = PostStms {unPostStms :: Stms rep}

-- | Note the flipped order: the statements of the /right/ operand end
-- up in front of those of the left operand.
instance Semigroup (PostStms rep) where
  PostStms xs <> PostStms ys = PostStms $ ys <> xs

-- | The empty sequence of post-statements.
instance Monoid (PostStms rep) where
  mempty = PostStms mempty

-- | The scope introduced by the pattern of the outermost target of
-- the accumulator.
typeEnvFromDistAcc :: DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc = scopeOfPat . fst . outerTarget . distTargets

-- | Put the given statements in front of the statements already held
-- by the accumulator.
addStmsToAcc :: Stms rep -> DistAcc rep -> DistAcc rep
addStmsToAcc stms acc =
  acc {distStms = stms <> distStms acc}

-- | Translate a single SOACS statement with the 'distOnSOACSStms'
-- callback and put the resulting statements in front of those already
-- held by the accumulator.  Only the value returned by the callback is
-- used; any statements the builder itself collected are discarded.
addStmToAcc ::
  (MonadFreshNames m, DistRep rep) =>
  Stm SOACS ->
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
addStmToAcc stm acc = do
  onSoacs <- asks distOnSOACSStms
  (stm', _) <- runBuilder $ onSoacs stm
  return acc {distStms = stm' <> distStms acc}

-- | Translate a SOACS lambda with the 'distOnSOACSLambda' callback.
-- Only the resulting lambda is kept; any statements collected by the
-- builder are discarded.
soacsLambda ::
  (MonadFreshNames m, DistRep rep) =>
  Lambda SOACS ->
  DistNestT rep m (Lambda rep)
soacsLambda lam = do
  onLambda <- asks distOnSOACSLambda
  fst <$> runBuilder (onLambda lam)

-- | The distribution monad transformer: a reader of 'DistEnv' layered
-- over a writer of 'DistRes', on top of the underlying monad @m@.
-- All instances below are derived from the underlying transformer
-- stack.
newtype DistNestT rep m a
  = DistNestT (ReaderT (DistEnv rep m) (WriterT (DistRes rep) m) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (DistEnv rep m),
      MonadWriter (DistRes rep)
    )

-- | Run an action of the underlying monad inside 'DistNestT'.
--
-- The scope of 'DistNestT' and the scope of the underlying monad are
-- tracked separately, so before running the action we extend the inner
-- scope with every binding that 'DistNestT' knows about but the
-- underlying monad does not.
liftInner :: (LocalScope rep m, DistRep rep) => m a -> DistNestT rep m a
liftInner m = do
  outer_scope <- askScope
  DistNestT $
    lift $
      lift $ do
        inner_scope <- askScope
        -- Only the bindings missing from the inner scope are added.
        localScope (outer_scope `M.difference` inner_scope) m

-- | The name source lives in the underlying monad; both operations
-- are simply lifted through the transformer stack.
instance MonadFreshNames m => MonadFreshNames (DistNestT rep m) where
  getNameSource = DistNestT $ lift getNameSource
  putNameSource = DistNestT . lift . putNameSource

-- | The scope is the 'distScope' field of the environment.
instance (Monad m, ASTRep rep) => HasScope rep (DistNestT rep m) where
  askScope = asks distScope

-- | Extending the scope locally puts the new bindings in front of
-- 'distScope' for the duration of the given action.
instance (Monad m, ASTRep rep) => LocalScope rep (DistNestT rep m) where
  localScope types = local $ \env ->
    env {distScope = types <> distScope env}

-- | Log messages are collected in the 'accLog' field of the writer
-- output; 'runDistNestT' forwards them to the underlying monad.
instance Monad m => MonadLogger (DistNestT rep m) where
  addLog msgs = tell mempty {accLog = msgs}

-- | Run a 'DistNestT' computation in the given environment.
--
-- The collected log is forwarded to the underlying monad.  The result
-- is the collected post-statements followed by one statement per
-- pattern element of the outermost remaining target.
runDistNestT ::
  (MonadLogger m, DistRep rep) =>
  DistEnv rep m ->
  DistNestT rep m (DistAcc rep) ->
  m (Stms rep)
runDistNestT env (DistNestT m) = do
  (acc, res) <- runWriterT $ runReaderT m env
  addLog $ accLog res
  -- There may be a few final targets remaining - these correspond to
  -- arrays that are identity mapped, and must have statements
  -- inserted here.
  pure $
    unPostStms (accPostStms res) <> identityStms (outerTarget $ distTargets acc)
  where
    -- The loop of the first nesting in the list component of
    -- 'distNest', or of the lone first component when that list is
    -- empty.
    outermost = nestingLoop $
      case distNest env of
        (nest, []) -> nest
        (_, nest : _) -> nest

    -- Association list from the parameter names of 'outermost' to the
    -- arrays paired with them.
    params_to_arrs =
      map (first paramName) $
        loopNestingParamsAndArrs outermost

    -- One statement per (pattern element, result) pair of the target.
    identityStms (rem_pat, res) =
      stmsFromList $ zipWith identityStm (patElems rem_pat) res

    -- A result that is a parameter of 'outermost' becomes a copy of
    -- the array paired with that parameter ...
    identityStm pe (SubExpRes cs (Var v))
      | Just arr <- lookup v params_to_arrs =
          certify cs $ Let (Pat [pe]) (defAux ()) $ BasicOp $ Copy arr
    -- ... and anything else is replicated across the width of
    -- 'outermost'.
    identityStm pe (SubExpRes cs se) =
      certify cs . Let (Pat [pe]) (defAux ()) . BasicOp $
        Replicate (Shape [loopNestingWidth outermost]) se

-- | Emit post-statements to the writer output.  They end up in the
-- 'accPostStms' field and are returned by 'runDistNestT'.
addPostStms :: Monad m => PostStms rep -> DistNestT rep m ()
addPostStms ks = tell $ mempty {accPostStms = ks}

-- | Like 'addPostStms', but takes plain 'Stms' and wraps them in
-- 'PostStms' first.
postStm :: Monad m => Stms rep -> DistNestT rep m ()
postStm stms = addPostStms $ PostStms stms

-- | Run an action in an environment where the given SOACS statement
-- has been bound: the names it binds are added to the scope, and are
-- registered with the nestings through 'letBindInInnerNesting'.
withStm ::
  (Monad m, DistRep rep) =>
  Stm SOACS ->
  DistNestT rep m a ->
  DistNestT rep m a
withStm stm = local $ \env ->
  env
    { distScope =
        castScope (scopeOf stm) <> distScope env,
      distNest =
        letBindInInnerNesting provided $
          distNest env
    }
  where
    -- The names bound by the pattern of the statement.
    provided = namesFromList $ patNames $ stmPat stm

leavingNesting ::
  (MonadFreshNames m, DistRep rep) =>
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
leavingNesting :: DistAcc rep -> DistNestT rep m (DistAcc rep)
leavingNesting DistAcc rep
acc =
  case Targets -> Maybe ((PatT Type, Result), Targets)
popInnerTarget (Targets -> Maybe ((PatT Type, Result), Targets))
-> Targets -> Maybe ((PatT Type, Result), Targets)
forall a b. (a -> b) -> a -> b
$ DistAcc rep -> Targets
forall rep. DistAcc rep -> Targets
distTargets DistAcc rep
acc of
    Maybe ((PatT Type, Result), Targets)
Nothing ->
      [Char] -> DistNestT rep m (DistAcc rep)
forall a. HasCallStack => [Char] -> a
error [Char]
"The kernel targets list is unexpectedly small"
    Just ((PatT Type
pat, Result
res), Targets
newtargets)
      | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Seq (Stm rep) -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (Seq (Stm rep) -> Bool) -> Seq (Stm rep) -> Bool
forall a b. (a -> b) -> a -> b
$ DistAcc rep -> Seq (Stm rep)
forall rep. DistAcc rep -> Stms rep
distStms DistAcc rep
acc -> do
        -- Any statements left over correspond to something that
        -- could not be distributed because it would cause irregular
        -- arrays.  These must be reconstructed into a Map SOAC
        -- that will be sequentialised. XXX: life would be better if
        -- we were able to distribute irregular parallelism.
        (Nesting Names
_ LoopNesting
inner, [Nesting]
_) <- (DistEnv rep m -> Nestings) -> DistNestT rep m Nestings
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks DistEnv rep m -> Nestings
forall rep (m :: * -> *). DistEnv rep m -> Nestings
distNest
        let MapNesting PatT Type
_ StmAux ()
aux SubExp
w [(Param Type, VName)]
params_and_arrs = LoopNesting
inner
            body :: BodyT rep
body = BodyDec rep -> Seq (Stm rep) -> Result -> BodyT rep
forall rep. BodyDec rep -> Stms rep -> Result -> BodyT rep
Body () (DistAcc rep -> Seq (Stm rep)
forall rep. DistAcc rep -> Stms rep
distStms DistAcc rep
acc) Result
res
            used_in_body :: Names
used_in_body = BodyT rep -> Names
forall a. FreeIn a => a -> Names
freeIn BodyT rep
body
            ([Param Type]
used_params, [VName]
used_arrs) =
              [(Param Type, VName)] -> ([Param Type], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Param Type, VName)] -> ([Param Type], [VName]))
-> [(Param Type, VName)] -> ([Param Type], [VName])
forall a b. (a -> b) -> a -> b
$
                ((Param Type, VName) -> Bool)
-> [(Param Type, VName)] -> [(Param Type, VName)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName -> Names -> Bool
`nameIn` Names
used_in_body) (VName -> Bool)
-> ((Param Type, VName) -> VName) -> (Param Type, VName) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName (Param Type -> VName)
-> ((Param Type, VName) -> Param Type)
-> (Param Type, VName)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param Type, VName) -> Param Type
forall a b. (a, b) -> a
fst) [(Param Type, VName)]
params_and_arrs
            lam' :: LambdaT rep
lam' =
              Lambda :: forall rep. [LParam rep] -> BodyT rep -> [Type] -> LambdaT rep
Lambda
                { lambdaParams :: [LParam rep]
lambdaParams = [Param Type]
[LParam rep]
used_params,
                  lambdaBody :: BodyT rep
lambdaBody = BodyT rep
body,
                  lambdaReturnType :: [Type]
lambdaReturnType = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ PatT Type -> [Type]
forall dec. Typed dec => PatT dec -> [Type]
patTypes PatT Type
pat
                }
        Seq (Stm rep)
stms <-
          Builder rep () -> DistNestT rep m (Seq (Stm rep))
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Builder rep () -> DistNestT rep m (Seq (Stm rep)))
-> (SOAC rep -> Builder rep ())
-> SOAC rep
-> DistNestT rep m (Seq (Stm rep))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StmAux () -> Builder rep () -> Builder rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (Builder rep () -> Builder rep ())
-> (SOAC rep -> Builder rep ()) -> SOAC rep -> Builder rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pat (Rep (BuilderT rep (State VNameSource)))
-> SOAC (Rep (BuilderT rep (State VNameSource))) -> Builder rep ()
forall (m :: * -> *).
Transformer m =>
Pat (Rep m) -> SOAC (Rep m) -> m ()
FOT.transformSOAC PatT Type
Pat (Rep (BuilderT rep (State VNameSource)))
pat (SOAC rep -> DistNestT rep m (Seq (Stm rep)))
-> SOAC rep -> DistNestT rep m (Seq (Stm rep))
forall a b. (a -> b) -> a -> b
$
            SubExp -> [VName] -> ScremaForm rep -> SOAC rep
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName]
used_arrs (ScremaForm rep -> SOAC rep) -> ScremaForm rep -> SOAC rep
forall a b. (a -> b) -> a -> b
$ LambdaT rep -> ScremaForm rep
forall rep. Lambda rep -> ScremaForm rep
mapSOAC LambdaT rep
lam'

        DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) a. Monad m => a -> m a
return (DistAcc rep -> DistNestT rep m (DistAcc rep))
-> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall a b. (a -> b) -> a -> b
$ DistAcc rep
acc {distTargets :: Targets
distTargets = Targets
newtargets, distStms :: Seq (Stm rep)
distStms = Seq (Stm rep)
stms}
      | Bool
otherwise -> do
        -- Any results left over correspond to a Replicate or a Copy in
        -- the parent nesting, depending on whether the argument is a
        -- parameter of the innermost nesting.
        (Nesting Names
_ LoopNesting
inner_nesting, [Nesting]
_) <- (DistEnv rep m -> Nestings) -> DistNestT rep m Nestings
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks DistEnv rep m -> Nestings
forall rep (m :: * -> *). DistEnv rep m -> Nestings
distNest
        let w :: SubExp
w = LoopNesting -> SubExp
loopNestingWidth LoopNesting
inner_nesting
            aux :: StmAux ()
aux = LoopNesting -> StmAux ()
loopNestingAux LoopNesting
inner_nesting
            inps :: [(Param Type, VName)]
inps = LoopNesting -> [(Param Type, VName)]
loopNestingParamsAndArrs LoopNesting
inner_nesting

            remnantStm :: PatElemT Type -> SubExpRes -> Stm rep
remnantStm PatElemT Type
pe (SubExpRes Certs
cs (Var VName
v))
              | Just (Param Type
_, VName
arr) <- ((Param Type, VName) -> Bool)
-> [(Param Type, VName)] -> Maybe (Param Type, VName)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v) (VName -> Bool)
-> ((Param Type, VName) -> VName) -> (Param Type, VName) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName (Param Type -> VName)
-> ((Param Type, VName) -> Param Type)
-> (Param Type, VName)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param Type, VName) -> Param Type
forall a b. (a, b) -> a
fst) [(Param Type, VName)]
inps =
                Certs -> Stm rep -> Stm rep
forall rep. Certs -> Stm rep -> Stm rep
certify Certs
cs (Stm rep -> Stm rep) -> Stm rep -> Stm rep
forall a b. (a -> b) -> a -> b
$ Pat rep -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
forall rep. Pat rep -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElemT Type] -> PatT Type
forall dec. [PatElemT dec] -> PatT dec
Pat [PatElemT Type
pe]) StmAux ()
StmAux (ExpDec rep)
aux (Exp rep -> Stm rep) -> Exp rep -> Stm rep
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
arr
            remnantStm PatElemT Type
pe (SubExpRes Certs
cs SubExp
se) =
              Certs -> Stm rep -> Stm rep
forall rep. Certs -> Stm rep -> Stm rep
certify Certs
cs (Stm rep -> Stm rep) -> Stm rep -> Stm rep
forall a b. (a -> b) -> a -> b
$ Pat rep -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
forall rep. Pat rep -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElemT Type] -> PatT Type
forall dec. [PatElemT dec] -> PatT dec
Pat [PatElemT Type
pe]) StmAux ()
StmAux (ExpDec rep)
aux (Exp rep -> Stm rep) -> Exp rep -> Stm rep
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
se

            stms :: Seq (Stm rep)
stms =
              [Stm rep] -> Seq (Stm rep)
forall rep. [Stm rep] -> Stms rep
stmsFromList ([Stm rep] -> Seq (Stm rep)) -> [Stm rep] -> Seq (Stm rep)
forall a b. (a -> b) -> a -> b
$ (PatElemT Type -> SubExpRes -> Stm rep)
-> [PatElemT Type] -> Result -> [Stm rep]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith PatElemT Type -> SubExpRes -> Stm rep
remnantStm (PatT Type -> [PatElemT Type]
forall dec. PatT dec -> [PatElemT dec]
patElems PatT Type
pat) Result
res

        DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) a. Monad m => a -> m a
return (DistAcc rep -> DistNestT rep m (DistAcc rep))
-> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall a b. (a -> b) -> a -> b
$ DistAcc rep
acc {distTargets :: Targets
distTargets = Targets
newtargets, distStms :: Seq (Stm rep)
distStms = Seq (Stm rep)
stms}

-- | Run an action inside a new innermost map nesting.
--
-- The nesting is a 'MapNesting' built from the given pattern, auxiliary
-- information and width, pairing each parameter of the lambda with the
-- corresponding array in @arrs@.  While the action runs, the nesting is
-- pushed onto 'distNest' and the scope of the lambda is added to
-- 'distScope'.  The accumulator produced by the action is then handed to
-- 'leavingNesting', which also runs in the extended environment.
mapNesting ::
  (MonadFreshNames m, DistRep rep) =>
  PatT Type ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  DistNestT rep m (DistAcc rep) ->
  DistNestT rep m (DistAcc rep)
mapNesting pat aux w lam arrs m =
  local extend $ leavingNesting =<< m
  where
    -- The new innermost nesting; it starts with an empty name set.
    nest =
      Nesting mempty $
        MapNesting pat aux w $
          zip (lambdaParams lam) arrs
    -- Environment as seen from inside the new nesting.
    extend env =
      env
        { distNest = pushInnerNesting nest $ distNest env,
          distScope = castScope (scopeOf lam) <> distScope env
        }

-- | Run an action with the environment replaced by the given kernel nest.
--
-- 'distNest' is overwritten (not extended) with the loop nestings of the
-- kernel nest: the last element of @nests@ becomes the innermost nesting,
-- or @outer@ itself when @nests@ is empty.  The scopes of all loop
-- nestings are added in front of the existing 'distScope'.
inNesting ::
  (Monad m, DistRep rep) =>
  KernelNest ->
  DistNestT rep m a ->
  DistNestT rep m a
inNesting (outer, nests) = local $ \env ->
  env
    { distNest = (inner, nests'),
      distScope = foldMap scopeOfLoopNesting (outer : nests) <> distScope env
    }
  where
    -- Split off the innermost loop nesting; the remaining ones are kept
    -- in outermost-first order.
    (inner, nests') =
      case reverse nests of
        [] -> (asNesting outer, [])
        (inner' : ns) -> (asNesting inner', map asNesting $ outer : reverse ns)
    -- Wrap a loop nesting as a 'Nesting' with an empty name set.
    asNesting = Nesting mempty

-- | Does any statement of the body count as parallel?
--
-- A statement counts when its expression is accepted by @isMap@ and the
-- statement does not carry the @sequential@ attribute.  Only the
-- top-level statements of the body are inspected directly; for-loops and
-- 'WithAcc' are the only constructs whose bodies are searched
-- recursively.
bodyContainsParallelism :: Body SOACS -> Bool
bodyContainsParallelism = any isParallelStm . bodyStms
  where
    isParallelStm stm =
      isMap (stmExp stm)
        && not ("sequential" `inAttrs` stmAuxAttrs (stmAux stm))
    isMap BasicOp {} = False
    isMap Apply {} = False
    isMap If {} = False
    -- A for-loop is as parallel as its body; while-loops never count.
    isMap (DoLoop _ ForLoop {} body) = bodyContainsParallelism body
    isMap (DoLoop _ WhileLoop {} _) = False
    isMap (WithAcc _ lam) = bodyContainsParallelism $ lambdaBody lam
    -- Every 'Op' counts, whatever it is.
    isMap Op {} = True

-- | 'bodyContainsParallelism' applied to the body of a lambda.
lambdaContainsParallelism :: Lambda SOACS -> Bool
lambdaContainsParallelism = bodyContainsParallelism . lambdaBody

-- | Feed the statements of a map body through 'maybeDistributeStm',
-- starting from the given accumulator, and finally call 'distribute' on
-- the result.
--
-- Statements are handled back to front: the remaining statements are
-- processed first, and the resulting accumulator is passed to
-- 'maybeDistributeStm' together with the current statement.
--
-- A sequential 'Stream' is not passed on as-is.  It is expanded with
-- 'sequentialStreamWholeArray', the result is copy-propagated, the
-- certificates of the original statement are attached to every produced
-- statement, and the produced statements are processed in its place.
distributeMapBodyStms ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  Stms SOACS ->
  DistNestT rep m (DistAcc rep)
distributeMapBodyStms orig_acc = distribute <=< onStms orig_acc . stmsToList
  where
    onStms acc [] = return acc
    onStms acc (Let pat (StmAux cs _ _) (Op (Stream w arrs Sequential accs lam)) : stms) = do
      types <- asksScope scopeForSOACs
      stream_stms <-
        snd <$> runBuilderT (sequentialStreamWholeArray pat w accs lam arrs) types
      stream_stms' <-
        runReaderT (copyPropagateInStms simpleSOACS types stream_stms) types
      -- Only the certificates of the stream statement are carried over
      -- to the statements that replace it.
      onStms acc $ stmsToList (fmap (certify cs) stream_stms') ++ stms
    onStms acc (stm : stms) =
      -- It is important that stm is in scope if 'maybeDistributeStm'
      -- wants to distribute, even if this causes the slightly silly
      -- situation that stm is in scope of itself.
      withStm stm $ maybeDistributeStm stm =<< onStms acc stms

-- | Apply the 'distOnInnerMap' callback stored in the environment to the
-- given map loop and accumulator.
onInnerMap :: Monad m => MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
onInnerMap loop acc = do
  f <- asks distOnInnerMap
  f loop acc

-- | Transform the statements with the 'distOnTopLevelStms' callback
-- stored in the environment, and pass the result to 'postStm'.
onTopLevelStms :: Monad m => Stms SOACS -> DistNestT rep m ()
onTopLevelStms stms = do
  f <- asks distOnTopLevelStms
  postStm =<< f stms

maybeDistributeStm ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  Stm SOACS ->
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
maybeDistributeStm :: Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
maybeDistributeStm Stm SOACS
stm DistAcc rep
acc
  | Attr
"sequential" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs (Stm SOACS -> StmAux (ExpDec SOACS)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm SOACS
stm) =
    Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
maybeDistributeStm (Let Pat
pat StmAux (ExpDec SOACS)
aux (Op Op SOACS
soac)) DistAcc rep
acc
  | Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux =
    DistAcc rep -> Stms SOACS -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> Stms SOACS -> DistNestT rep m (DistAcc rep)
distributeMapBodyStms DistAcc rep
acc (Stms SOACS -> DistNestT rep m (DistAcc rep))
-> (Stms SOACS -> Stms SOACS)
-> Stms SOACS
-> DistNestT rep m (DistAcc rep)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm SOACS -> Stm SOACS) -> Stms SOACS -> Stms SOACS
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> Stm SOACS -> Stm SOACS
forall rep. Certs -> Stm rep -> Stm rep
certify (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux))
      (Stms SOACS -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (Stms SOACS) -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Builder SOACS () -> DistNestT rep m (Stms SOACS)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Pat (Rep (BuilderT SOACS (State VNameSource)))
-> SOAC (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS ()
forall (m :: * -> *).
Transformer m =>
Pat (Rep m) -> SOAC (Rep m) -> m ()
FOT.transformSOAC Pat (Rep (BuilderT SOACS (State VNameSource)))
Pat
pat Op SOACS
SOAC (Rep (BuilderT SOACS (State VNameSource)))
soac)
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat
pat StmAux (ExpDec SOACS)
_ (Op (Screma w arrs form))) DistAcc rep
acc
  | Just Lambda
lam <- ScremaForm SOACS -> Maybe Lambda
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form =
    -- Only distribute inside the map if we can distribute everything
    -- following the map.
    DistAcc rep -> DistNestT rep m (Maybe (DistAcc rep))
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> DistNestT rep m (Maybe (DistAcc rep))
distributeIfPossible DistAcc rep
acc DistNestT rep m (Maybe (DistAcc rep))
-> (Maybe (DistAcc rep) -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
      Maybe (DistAcc rep)
Nothing -> Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
      Just DistAcc rep
acc' -> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> DistNestT rep m (DistAcc rep)
distribute (DistAcc rep -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
Monad m =>
MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
onInnerMap (Pat -> StmAux () -> SubExp -> Lambda -> [VName] -> MapLoop
MapLoop Pat
pat (Stm SOACS -> StmAux (ExpDec SOACS)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm SOACS
stm) SubExp
w Lambda
lam [VName]
arrs) DistAcc rep
acc'
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat
pat StmAux (ExpDec SOACS)
aux (DoLoop [(FParam SOACS, SubExp)]
merge form :: LoopForm SOACS
form@ForLoop {} Body SOACS
body)) DistAcc rep
acc
  | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (VName -> Names -> Bool
`nameIn` PatT Type -> Names
forall a. FreeIn a => a -> Names
freeIn PatT Type
Pat
pat) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ PatT Type -> [VName]
forall dec. PatT dec -> [VName]
patNames PatT Type
Pat
pat,
    Body SOACS -> Bool
bodyContainsParallelism Body SOACS
body =
    DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm DistNestT
  rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
-> (Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
    -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
      Just (PostStms rep
kernels, Result
res, KernelNest
nest, DistAcc rep
acc')
        | -- XXX: We cannot distribute if this loop depends on
          -- certificates bound within the loop nest (well, we could,
          -- but interchange would not be valid).  This is not a
          -- fundamental restriction, but an artifact of our
          -- certificate representation, which we should probably
          -- rethink.
          Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$
            (LoopForm SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm SOACS
form Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> StmAux () -> Names
forall a. FreeIn a => a -> Names
freeIn StmAux ()
StmAux (ExpDec SOACS)
aux)
              Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
          Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatT Type
Pat
pat Result
res ->
          -- We need to pretend pat_unused was used anyway, by adding
          -- it to the kernel nest.
          Scope rep
-> DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (DistAcc rep -> Scope rep
forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') (DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep)
forall a b. (a -> b) -> a -> b
$ do
            PostStms rep -> DistNestT rep m ()
forall (m :: * -> *) rep.
Monad m =>
PostStms rep -> DistNestT rep m ()
addPostStms PostStms rep
kernels
            KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT rep m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
            Scope SOACS
types <- (Scope rep -> Scope SOACS) -> DistNestT rep m (Scope SOACS)
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope rep -> Scope SOACS
forall rep. SameScope rep SOACS => Scope rep -> Scope SOACS
scopeForSOACs

            -- Simplification is key to hoisting out statements that
            -- were variant to the loop, but invariant to the outer maps
            -- (which are now innermost).
            Stms SOACS
stms <-
              (ReaderT (Scope SOACS) (DistNestT rep m) (Stms SOACS)
-> Scope SOACS -> DistNestT rep m (Stms SOACS)
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
`runReaderT` Scope SOACS
types) (ReaderT (Scope SOACS) (DistNestT rep m) (Stms SOACS)
 -> DistNestT rep m (Stms SOACS))
-> ReaderT (Scope SOACS) (DistNestT rep m) (Stms SOACS)
-> DistNestT rep m (Stms SOACS)
forall a b. (a -> b) -> a -> b
$
                Stms SOACS -> ReaderT (Scope SOACS) (DistNestT rep m) (Stms SOACS)
forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Stms SOACS -> m (Stms SOACS)
simplifyStms (Stms SOACS
 -> ReaderT (Scope SOACS) (DistNestT rep m) (Stms SOACS))
-> ReaderT (Scope SOACS) (DistNestT rep m) (Stms SOACS)
-> ReaderT (Scope SOACS) (DistNestT rep m) (Stms SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< KernelNest
-> SeqLoop -> ReaderT (Scope SOACS) (DistNestT rep m) (Stms SOACS)
forall (m :: * -> *).
(MonadFreshNames m, HasScope SOACS m) =>
KernelNest -> SeqLoop -> m (Stms SOACS)
interchangeLoops KernelNest
nest' ([Int]
-> Pat
-> [(FParam SOACS, SubExp)]
-> LoopForm SOACS
-> Body SOACS
-> SeqLoop
SeqLoop [Int]
perm Pat
pat [(FParam SOACS, SubExp)]
merge LoopForm SOACS
form Body SOACS
body)
            Stms SOACS -> DistNestT rep m ()
forall (m :: * -> *) rep.
Monad m =>
Stms SOACS -> DistNestT rep m ()
onTopLevelStms Stms SOACS
stms
            DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc rep
acc'
      Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ ->
        Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat
pat StmAux (ExpDec SOACS)
_ (If SubExp
cond Body SOACS
tbranch Body SOACS
fbranch IfDec (BranchType SOACS)
ret)) DistAcc rep
acc
  | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (VName -> Names -> Bool
`nameIn` PatT Type -> Names
forall a. FreeIn a => a -> Names
freeIn PatT Type
Pat
pat) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ PatT Type -> [VName]
forall dec. PatT dec -> [VName]
patNames PatT Type
Pat
pat,
    Body SOACS -> Bool
bodyContainsParallelism Body SOACS
tbranch Bool -> Bool -> Bool
|| Body SOACS -> Bool
bodyContainsParallelism Body SOACS
fbranch
      Bool -> Bool -> Bool
|| Bool -> Bool
not ((TypeBase ExtShape NoUniqueness -> Bool)
-> [TypeBase ExtShape NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase ExtShape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (IfDec (TypeBase ExtShape NoUniqueness)
-> [TypeBase ExtShape NoUniqueness]
forall rt. IfDec rt -> [rt]
ifReturns IfDec (TypeBase ExtShape NoUniqueness)
IfDec (BranchType SOACS)
ret)) =
    DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm DistNestT
  rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
-> (Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
    -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
      Just (PostStms rep
kernels, Result
res, KernelNest
nest, DistAcc rep
acc')
        | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$
            (SubExp -> Names
forall a. FreeIn a => a -> Names
freeIn SubExp
cond Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> IfDec (TypeBase ExtShape NoUniqueness) -> Names
forall a. FreeIn a => a -> Names
freeIn IfDec (TypeBase ExtShape NoUniqueness)
IfDec (BranchType SOACS)
ret) Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
          Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatT Type
Pat
pat Result
res ->
          -- We need to pretend pat_unused was used anyway, by adding
          -- it to the kernel nest.
          Scope rep
-> DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (DistAcc rep -> Scope rep
forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') (DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep)
forall a b. (a -> b) -> a -> b
$ do
            KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT rep m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
            PostStms rep -> DistNestT rep m ()
forall (m :: * -> *) rep.
Monad m =>
PostStms rep -> DistNestT rep m ()
addPostStms PostStms rep
kernels
            Scope SOACS
types <- (Scope rep -> Scope SOACS) -> DistNestT rep m (Scope SOACS)
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope rep -> Scope SOACS
forall rep. SameScope rep SOACS => Scope rep -> Scope SOACS
scopeForSOACs
            let branch :: Branch
branch = [Int]
-> Pat
-> SubExp
-> Body SOACS
-> Body SOACS
-> IfDec (BranchType SOACS)
-> Branch
Branch [Int]
perm Pat
pat SubExp
cond Body SOACS
tbranch Body SOACS
fbranch IfDec (BranchType SOACS)
ret
            Stms SOACS
stms <-
              (ReaderT (Scope SOACS) (DistNestT rep m) (Stms SOACS)
-> Scope SOACS -> DistNestT rep m (Stms SOACS)
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
`runReaderT` Scope SOACS
types) (ReaderT (Scope SOACS) (DistNestT rep m) (Stms SOACS)
 -> DistNestT rep m (Stms SOACS))
-> ReaderT (Scope SOACS) (DistNestT rep m) (Stms SOACS)
-> DistNestT rep m (Stms SOACS)
forall a b. (a -> b) -> a -> b
$
                Stms SOACS -> ReaderT (Scope SOACS) (DistNestT rep m) (Stms SOACS)
forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Stms SOACS -> m (Stms SOACS)
simplifyStms (Stms SOACS
 -> ReaderT (Scope SOACS) (DistNestT rep m) (Stms SOACS))
-> ReaderT (Scope SOACS) (DistNestT rep m) (Stms SOACS)
-> ReaderT (Scope SOACS) (DistNestT rep m) (Stms SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< KernelNest
-> Branch -> ReaderT (Scope SOACS) (DistNestT rep m) (Stms SOACS)
forall (m :: * -> *).
(MonadFreshNames m, HasScope SOACS m) =>
KernelNest -> Branch -> m (Stms SOACS)
interchangeBranch KernelNest
nest' Branch
branch
            Stms SOACS -> DistNestT rep m ()
forall (m :: * -> *) rep.
Monad m =>
Stms SOACS -> DistNestT rep m ()
onTopLevelStms Stms SOACS
stms
            DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc rep
acc'
      Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ ->
        Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
-- A 'WithAcc' whose lambda contains parallelism.  We try to
-- distribute the statement by itself; if that works out, the
-- 'WithAcc' is interchanged with the resulting kernel nest via
-- 'interchangeWithAcc', and the simplified result is passed to
-- 'onTopLevelStms'.  In every other case the statement is simply
-- added to the accumulator unchanged.
maybeDistributeStm stm@(Let pat _ (WithAcc inputs lam)) acc
  | lambdaContainsParallelism lam =
    distributeSingleStm acc stm >>= \case
      Just (kernels, res, nest, acc')
        | -- Only proceed when the lambda return types after the first
          -- num_accs ones mention no name bound in the kernel nest.
          not $
            freeIn (drop num_accs (lambdaReturnType lam))
              `namesIntersect` boundInKernelNest nest,
          Just (perm, pat_unused) <- permutationAndMissing pat res ->
          -- We need to pretend pat_unused was used anyway, by adding
          -- it to the kernel nest.
          localScope (typeEnvFromDistAcc acc') $ do
            nest' <- expandKernelNest pat_unused nest
            types <- asksScope scopeForSOACs
            addPostStms kernels
            let withacc = WithAccStm perm pat inputs lam
            -- The interchange and the simplification both run in a
            -- reader over the SOACS view of the current scope.
            stms <-
              (`runReaderT` types) $
                simplifyStms =<< interchangeWithAcc nest' withacc
            onTopLevelStms stms
            return acc'
      _ ->
        addStmToAcc stm acc
  where
    -- One entry per 'WithAcc' input.
    num_accs = length inputs
-- A Screma that 'isReduceSOAC' recognises as exactly one reduction,
-- and for which 'irwim' produces a rewrite.  The rewrite is run as a
-- builder action (with the original statement's aux attached through
-- 'auxing'), and the statements it emits are distributed in place of
-- the original statement.
maybeDistributeStm (Let pat aux (Op (Screma w arrs form))) acc
  | Just [Reduce comm lam nes] <- isReduceSOAC form,
    Just m <- irwim pat w comm lam $ zip nes arrs = do
    types <- asksScope scopeForSOACs
    (_, stms) <- runBuilderT (auxing aux m) types
    distributeMapBodyStms acc stms

-- Parallelise segmented scatters.
--
-- The scatter is distributed by itself; when that succeeds and the
-- pattern can be matched against the nest result
-- ('permutationAndMissing'), the statements built by
-- 'segmentedScatterKernel' are emitted as post-statements.  Otherwise
-- the statement is added to the accumulator unchanged.
maybeDistributeStm stm@(Let pat (StmAux cs _ _) (Op (Scatter w ivs lam as))) acc =
  distributeSingleStm acc stm >>= \case
    Just (kernels, res, nest, acc')
      | Just (perm, pat_unused) <- permutationAndMissing pat res ->
        localScope (typeEnvFromDistAcc acc') $ do
          -- Extend the nest with the pattern elements reported as
          -- missing, so they are treated as used.
          nest' <- expandKernelNest pat_unused nest
          lam' <- soacsLambda lam
          addPostStms kernels
          postStm =<< segmentedScatterKernel nest' perm pat cs w lam' ivs as
          return acc'
    _ ->
      addStmToAcc stm acc
-- Parallelise segmented Hist.
--
-- Same shape as the scatter case: distribute the statement by
-- itself, and on success emit the statements built by
-- 'segmentedHistKernel' as post-statements.  Otherwise the statement
-- is added to the accumulator unchanged.
maybeDistributeStm stm@(Let pat (StmAux cs _ _) (Op (Hist w as ops lam))) acc =
  distributeSingleStm acc stm >>= \case
    Just (kernels, res, nest, acc')
      | Just (perm, pat_unused) <- permutationAndMissing pat res ->
        localScope (typeEnvFromDistAcc acc') $ do
          lam' <- soacsLambda lam
          -- Extend the nest with the pattern elements reported as
          -- missing, so they are treated as used.
          nest' <- expandKernelNest pat_unused nest
          addPostStms kernels
          postStm =<< segmentedHistKernel nest' perm cs w ops lam' as
          return acc'
    _ ->
      addStmToAcc stm acc
-- Parallelise Index slices if the result is going to be returned
-- directly from the kernel.  This is because we would otherwise have
-- to sequentialise writing the result, which may be costly.
--
-- The guards require that the slice has at least one slice dimension
-- ('sliceDims' is non-empty) and that the bound name occurs among the
-- results of the inner target of the accumulator.
maybeDistributeStm stm@(Let (Pat [pe]) aux (BasicOp (Index arr slice))) acc
  | not $ null $ sliceDims slice,
    Var (patElemName pe) `elem` map resSubExp (snd (innerTarget (distTargets acc))) =
    distributeSingleStm acc stm >>= \case
      Just (kernels, _res, nest, acc') ->
        localScope (typeEnvFromDistAcc acc') $ do
          addPostStms kernels
          -- Only the certificates of the statement's aux are passed on.
          postStm =<< segmentedGatherKernel nest (stmAuxCerts aux) arr slice
          return acc'
      _ ->
        addStmToAcc stm acc
-- If the scan can be distributed by itself, we will turn it into a
-- segmented scan.
--
-- If the scan cannot be distributed by itself, it will be
-- sequentialised in the default case for this function.
maybeDistributeStm stm@(Let pat (StmAux cs _ _) (Op (Screma w arrs form))) acc
  | Just (scans, map_lam) <- isScanomapSOAC form,
    -- All scan operators are combined into one by 'singleScan'.
    Scan lam nes <- singleScan scans =
    distributeSingleStm acc stm >>= \case
      Just (kernels, res, nest, acc')
        | Just (perm, pat_unused) <- permutationAndMissing pat res ->
          -- We need to pretend pat_unused was used anyway, by adding
          -- it to the kernel nest.
          localScope (typeEnvFromDistAcc acc') $ do
            nest' <- expandKernelNest pat_unused nest
            map_lam' <- soacsLambda map_lam
            -- NOTE(review): this inner localScope repeats the scope
            -- already installed by the enclosing one; it looks
            -- redundant -- confirm before removing.
            localScope (typeEnvFromDistAcc acc') $
              -- The certificates cs go to the kernel builder;
              -- 'kernelOrNot' is given an empty certificate set.
              segmentedScanomapKernel nest' perm cs w lam map_lam' nes arrs
                >>= kernelOrNot mempty stm acc kernels acc'
      _ ->
        addStmToAcc stm acc
-- If the map function of the reduction contains parallelism we split
-- it, so that the parallelism can be exploited.
--
-- The redomap is split by 'redomapToMapAndReduce' into a map
-- statement followed by a reduce statement, and both are distributed
-- again.  The aux of the original statement is put on the map
-- statement only.
maybeDistributeStm (Let pat aux (Op (Screma w arrs form))) acc
  | Just (reds, map_lam) <- isRedomapSOAC form,
    lambdaContainsParallelism map_lam = do
    (mapstm, redstm) <-
      redomapToMapAndReduce pat (w, reds, map_lam, arrs)
    distributeMapBodyStms acc $ oneStm mapstm {stmAux = aux} <> oneStm redstm
-- If the reduction can be distributed by itself, we will turn it into a
-- segmented reduce.
--
-- If the reduction cannot be distributed by itself, it will be
-- sequentialised in the default case for this function.
maybeDistributeStm stm@(Let pat (StmAux cs _ _) (Op (Screma w arrs form))) acc
  | Just (reds, map_lam) <- isRedomapSOAC form,
    -- All reduction operators are combined into one by 'singleReduce'.
    Reduce comm lam nes <- singleReduce reds =
    distributeSingleStm acc stm >>= \case
      Just (kernels, res, nest, acc')
        | Just (perm, pat_unused) <- permutationAndMissing pat res ->
          -- We need to pretend pat_unused was used anyway, by adding
          -- it to the kernel nest.
          localScope (typeEnvFromDistAcc acc') $ do
            nest' <- expandKernelNest pat_unused nest

            lam' <- soacsLambda lam
            map_lam' <- soacsLambda map_lam

            -- Treat the operator as commutative if 'commutativeLambda'
            -- says so, even when the reduction was not declared as
            -- such; otherwise keep the declared commutativity.
            let comm'
                  | commutativeLambda lam = Commutative
                  | otherwise = comm

            -- The certificates cs go to the kernel builder;
            -- 'kernelOrNot' is given an empty certificate set.
            regularSegmentedRedomapKernel nest' perm cs w comm' lam' map_lam' nes arrs
              >>= kernelOrNot mempty stm acc kernels acc'
      _ ->
        addStmToAcc stm acc
-- Fallback for any Screma not handled by an earlier equation.
maybeDistributeStm (Let pat (StmAux cs _ _) (Op (Screma w arrs form))) acc = do
  -- This Screma is too complicated for us to immediately do
  -- anything, so split it up and try again.
  scope <- asksScope scopeForSOACs
  -- Every statement produced by 'dissectScrema' gets the certificates
  -- of the original statement before being distributed.
  distributeMapBodyStms acc . fmap (certify cs) . snd
    =<< runBuilderT (dissectScrema pat w form arrs) scope
-- A multi-dimensional replicate producing a single value is rewritten
-- as a map of width @d@ with no input arrays, whose body replicates
-- the value over the remaining dimensions @ds@.  The rewritten
-- statement is then handed back to 'maybeDistributeStm'.
maybeDistributeStm (Let pat aux (BasicOp (Replicate (Shape (d : ds)) v))) acc
  | [t] <- patTypes pat = do
    tmp <- newVName "tmp"
    let rowt = rowType t
        -- The replacement statement: a map over zero arrays.
        newstm = Let pat aux $ Op $ Screma d [] $ mapSOAC lam
        -- The per-iteration replicate over the inner dimensions.
        tmpstm =
          Let (Pat [PatElem tmp rowt]) aux $ BasicOp $ Replicate (Shape ds) v
        lam =
          Lambda
            { lambdaReturnType = [rowt],
              lambdaParams = [],
              lambdaBody = mkBody (oneStm tmpstm) [varRes tmp]
            }
    maybeDistributeStm newstm acc
-- A copy of a mapped array is distributed as a single copy of the
-- corresponding outer array, bound to the outermost nesting pattern.
maybeDistributeStm stm@(Let _ aux (BasicOp (Copy stm_arr))) acc =
  distributeSingleUnaryStm acc stm stm_arr $ \_ outerpat arr ->
    return $ oneStm $ Let outerpat aux $ BasicOp $ Copy arr
-- Opaques are applied to the full array, because otherwise they can
-- drastically inhibit parallelisation in some cases.
--
-- Only applies when the single pattern element is not of primitive
-- type.  The statement emitted for the full array is a 'Copy'.
maybeDistributeStm stm@(Let (Pat [pe]) aux (BasicOp (Opaque _ (Var stm_arr)))) acc
  | not $ primType $ typeOf pe =
    distributeSingleUnaryStm acc stm stm_arr $ \_ outerpat arr ->
      return $ oneStm $ Let outerpat aux $ BasicOp $ Copy arr
-- A rearrange inside a map nest becomes a rearrange of the full
-- array: the @r@ outer (nest) dimensions are kept in place and the
-- original permutation is shifted by @r@ to act on the inner ones.
maybeDistributeStm stm@(Let _ aux (BasicOp (Rearrange perm stm_arr))) acc =
  distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr -> do
    let r = length (snd nest) + 1
        perm' = [0 .. r - 1] ++ map (+ r) perm
    -- We need to add a copy, because the original map nest
    -- will have produced an array without aliases, and so must we.
    arr' <- newVName $ baseString arr
    arr_t <- lookupType arr
    return $
      stmsFromList
        [ Let (Pat [PatElem arr' arr_t]) aux $ BasicOp $ Copy arr,
          Let outerpat aux $ BasicOp $ Rearrange perm' arr'
        ]
-- A reshape inside a map nest becomes a reshape of the full array
-- whose new shape is the nest widths followed by the new dimensions
-- of the original reshape (all expressed as 'DimNew').
maybeDistributeStm stm@(Let _ aux (BasicOp (Reshape reshape stm_arr))) acc =
  distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr -> do
    let reshape' =
          map DimNew (kernelNestWidths nest)
            ++ map DimNew (newDims reshape)
    return $ oneStm $ Let outerpat aux $ BasicOp $ Reshape reshape' arr
-- A rotate inside a map nest becomes a rotate of the full array: one
-- zero rotation per nest width is prepended to the original
-- rotation amounts.
maybeDistributeStm stm@(Let _ aux (BasicOp (Rotate rots stm_arr))) acc =
  distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr -> do
    let rots' = map (const $ intConst Int64 0) (kernelNestWidths nest) ++ rots
    return $ oneStm $ Let outerpat aux $ BasicOp $ Rotate rots' arr
-- An in-place update with a variable as the written value, and whose
-- slice has at least one dimension, is turned into a segmented
-- update kernel when the statement can be distributed on its own and
-- its results are exactly the names bound by its pattern.  In every
-- other case the statement is simply added to the accumulator.
maybeDistributeStm stm@(Let pat aux (BasicOp (Update _ arr slice (Var v)))) acc
  | not $ null $ sliceDims slice =
    distributeSingleStm acc stm >>= \case
      Just (kernels, res, nest, acc')
        | map resSubExp res == map Var (patNames $ stmPat stm),
          Just (perm, pat_unused) <- permutationAndMissing pat res -> do
          addPostStms kernels
          localScope (typeEnvFromDistAcc acc') $ do
            -- Extend the nest with the pattern elements that the
            -- result does not account for.
            nest' <- expandKernelNest pat_unused nest
            postStm
              =<< segmentedUpdateKernel nest' perm (stmAuxCerts aux) arr slice v
            return acc'
      _ -> addStmToAcc stm acc
-- A concatenation is distributed on its own if possible and then
-- offered to 'isSegmentedOp'; 'kernelOrNot' decides what to do with
-- the outcome.  If the statement cannot be distributed on its own it
-- is added to the accumulator.
maybeDistributeStm stm@(Let _ aux (BasicOp (Concat d x xs w))) acc =
  distributeSingleStm acc stm >>= \case
    Just (kernels, _, nest, acc') ->
      localScope (typeEnvFromDistAcc acc') $
        segmentedConcat nest
          >>= kernelOrNot (stmAuxCerts aux) stm acc kernels acc'
    _ ->
      addStmToAcc stm acc
  where
    -- Build a concatenation of the full arrays along dimension @d@
    -- shifted by the depth of the nest (inner nestings plus the
    -- outermost one).
    segmentedConcat nest =
      isSegmentedOp nest [0] mempty mempty [] (x : xs) $
        \pat _ _ _ (x' : xs') ->
          let d' = d + length (snd nest) + 1
           in addStm $ Let pat aux $ BasicOp $ Concat d' x' xs' w
-- Fallback: any statement not handled by an earlier equation is just
-- added to the accumulator.
maybeDistributeStm stm acc =
  addStmToAcc stm acc

-- | Try to distribute a statement that operates on a single array
-- (@stm_arr@) as one statement on the corresponding outermost array.
--
-- The statement is first distributed on its own with
-- 'distributeSingleStm'.  The callback @f@ is used only if all of the
-- following hold:
--
-- * the results are exactly the names bound by the statement's pattern;
--
-- * the outermost nesting maps exactly one array;
--
-- * the only name bound in the nest that is free in the statement
--   is the parameter of that outermost nesting;
--
-- * the array is "perfectly mapped" down to @stm_arr@ (see below).
--
-- @f@ receives the kernel nest, the pattern of the outermost nesting,
-- and the outermost array; the statements it returns are posted with
-- 'postStm'.  Otherwise the statement is added to the accumulator.
distributeSingleUnaryStm ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  Stm SOACS ->
  VName ->
  (KernelNest -> PatT Type -> VName -> DistNestT rep m (Stms rep)) ->
  DistNestT rep m (DistAcc rep)
distributeSingleUnaryStm acc stm stm_arr f =
  distributeSingleStm acc stm >>= \case
    Just (kernels, res, nest, acc')
      | map resSubExp res == map Var (patNames $ stmPat stm),
        (outer, _) <- nest,
        [(arr_p, arr)] <- loopNestingParamsAndArrs outer,
        boundInKernelNest nest `namesIntersection` freeIn stm
          == oneName (paramName arr_p),
        perfectlyMapped arr nest -> do
        addPostStms kernels
        let outerpat = loopNestingPat $ fst nest
        localScope (typeEnvFromDistAcc acc') $ do
          postStm =<< f nest outerpat arr
          return acc'
    _ -> addStmToAcc stm acc
  where
    -- Each level of the nest must map exactly one array, namely the
    -- parameter of the level above, and the parameter of the
    -- innermost level must be @stm_arr@.
    perfectlyMapped arr (outer, nest)
      | [(p, arr')] <- loopNestingParamsAndArrs outer,
        arr == arr' =
        case nest of
          [] -> paramName p == stm_arr
          x : xs -> perfectlyMapped (paramName p) (x, xs)
      | otherwise =
        False

-- | Distribute the accumulated statements if 'distributeIfPossible'
-- succeeds; otherwise return the accumulator unchanged.
distribute ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
distribute acc =
  fromMaybe acc <$> distributeIfPossible acc

-- | Produce a level-constructing function usable in a builder on top
-- of 'DistNestT'.
--
-- The underlying function is read from the environment
-- ('distSegLevel').  When invoked, it is run with 'runBuilderT'' in
-- the inner monad (via 'liftInner'); the statements it produced are
-- then re-added to the calling builder with 'addStms' and the level
-- is returned.
mkSegLevel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistNestT rep m (MkSegLevel rep (DistNestT rep m))
mkSegLevel = do
  mk_lvl <- asks distSegLevel
  return $ \w desc r -> do
    (lvl, stms) <- lift $ liftInner $ runBuilderT' $ mk_lvl w desc r
    addStms stms
    return lvl

-- | Attempt to distribute the statements held in the accumulator.
--
-- Calls 'tryDistribute' with the current nestings ('distNest'), the
-- accumulator's targets and its statements.  On success the resulting
-- kernel statements are posted with 'postStm' and a fresh accumulator
-- is returned, carrying the new targets and no pending statements.
-- Returns 'Nothing' when 'tryDistribute' does.
distributeIfPossible ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  DistNestT rep m (Maybe (DistAcc rep))
distributeIfPossible acc = do
  nest <- asks distNest
  mk_lvl <- mkSegLevel
  tryDistribute mk_lvl nest (distTargets acc) (distStms acc) >>= \case
    Nothing -> return Nothing
    Just (targets, kernel) -> do
      postStm kernel
      return $
        Just
          DistAcc
            { distTargets = targets,
              distStms = mempty
            }

-- | Attempt to distribute the accumulated statements and then a single
-- additional SOACS statement.
--
-- This proceeds in two steps, either of which may fail (in which case
-- the overall result is 'Nothing'):
--
-- 1. The statements already in the accumulator are distributed with
--    'tryDistribute', exactly as in 'distributeIfPossible', except that
--    the resulting statements are /not/ posted here.
--
-- 2. The given statement is distributed with 'tryDistributeStm', using
--    the targets produced by step 1.
--
-- On success the result contains: the statements from step 1 wrapped
-- in 'PostStms' (left to the caller to post), the 'Result' and
-- 'KernelNest' produced by 'tryDistributeStm', and a new accumulator
-- with the targets from step 2 and an empty statement list.
distributeSingleStm ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  Stm SOACS ->
  DistNestT
    rep
    m
    ( Maybe
        ( PostStms rep,
          Result,
          KernelNest,
          DistAcc rep
        )
    )
distributeSingleStm acc stm = do
  nest <- asks distNest
  mk_lvl <- mkSegLevel
  tryDistribute mk_lvl nest (distTargets acc) (distStms acc) >>= \case
    Nothing -> return Nothing
    Just (targets, distributed_stms) ->
      tryDistributeStm nest targets stm >>= \case
        Nothing -> return Nothing
        Just (res, targets', new_kernel_nest) ->
          return $
            Just
              ( PostStms distributed_stms,
                res,
                new_kernel_nest,
                DistAcc
                  { distTargets = targets',
                    distStms = mempty
                  }
              )

-- | Build the statements for a scatter nested inside a map nest, as a
-- single map kernel whose results are in-place writes ('WriteReturns').
--
-- Arguments, in order:
--
-- * the surrounding 'KernelNest';
--
-- * a permutation that is applied (via 'rearrangeShape') to the pattern
--   elements of the first nesting of the 'KernelNest' when binding the
--   final kernel;
--
-- * the pattern, certificates and width of the scatter itself;
--
-- * the scatter lambda and the arrays it is mapped over (zipped with
--   the lambda parameters);
--
-- * the destinations, each a triple of a shape, a count and the
--   destination array name.  The number of index results consumed per
--   destination is the count multiplied by the rank of the shape.
--
-- Calls 'error' if a destination array cannot be found among the
-- kernel inputs computed by 'flatKernel'.
segmentedScatterKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  PatT Type ->
  Certs ->
  SubExp ->
  Lambda rep ->
  [VName] ->
  [(Shape, Int, VName)] ->
  DistNestT rep m (Stms rep)
segmentedScatterKernel nest perm scatter_pat cs scatter_w lam ivs dests = do
  -- We replicate some of the checking done by 'isSegmentedOp', but
  -- things are different because a scatter is not a reduction or
  -- scan.
  --
  -- First, pretend that the scatter is also part of the nesting.  The
  -- KernelNest we produce here is technically not sensible, but it's
  -- good enough for flatKernel to work.
  let nesting =
        MapNesting scatter_pat (StmAux cs mempty ()) scatter_w $ zip (lambdaParams lam) ivs
      nest' =
        pushInnerKernelNesting (scatter_pat, bodyResult $ lambdaBody lam) nesting nest
  (ispace, kernel_inps) <- flatKernel nest'

  let (as_ws, as_ns, as) = unzip3 dests
      -- Number of index results belonging to each destination.
      indexes = zipWith (*) as_ns $ map length as_ws

  -- The input/output arrays ('as') _must_ correspond to some kernel
  -- input, or else the original nested scatter would have been
  -- ill-typed.  Find them.
  as_inps <- mapM (findInput kernel_inps) as

  mk_lvl <- mkSegLevel

  -- The lambda returns all indexes first, then all values.  The
  -- kernel has one return type per destination: the first value type
  -- of each destination's chunk.
  let rts =
        concatMap (take 1) $
          chunks as_ns $
            drop (sum indexes) $ lambdaReturnType lam
      (is, vs) = splitAt (sum indexes) $ bodyResult $ lambdaBody lam

  (is', k_body_stms) <- runBuilder $ do
    addStms $ bodyStms $ lambdaBody lam
    pure is

  let k_body =
        groupScatterResults (zip3 as_ws as_ns as_inps) (is' ++ vs)
          & map (inPlaceReturn ispace)
          & KernelBody () k_body_stms
      -- Remove unused kernel inputs, since some of these might
      -- reference the array we are scattering into.
      kernel_inps' =
        filter ((`nameIn` freeIn k_body) . kernelInputName) kernel_inps

  (k, k_stms) <- mapKernel mk_lvl ispace kernel_inps' rts k_body

  traverse renameStm <=< runBuilder_ $ do
    addStms k_stms

    let pat =
          Pat . rearrangeShape perm $
            patElems $ loopNestingPat $ fst nest

    letBind pat $ Op $ segOp k
  where
    -- Look up the kernel input whose name is the given array.
    findInput kernel_inps a =
      maybe bad return $ find ((== a) . kernelInputName) kernel_inps
    bad = error "Ill-typed nested scatter encountered."

    -- Turn one group of scatter results into a 'WriteReturns'.  The
    -- innermost dimension of the iteration space (the scatter itself)
    -- is dropped from both the shape and the thread indexes, and is
    -- replaced by the destination shape and the computed indexes.
    inPlaceReturn ispace (aw, inp, is_vs) =
      WriteReturns
        ( foldMap (foldMap resCerts . fst) is_vs
            <> foldMap (resCerts . snd) is_vs
        )
        (Shape (init ws ++ shapeDims aw))
        (kernelInputArray inp)
        [ (Slice $ map DimFix $ map Var (init gtids) ++ map resSubExp is, resSubExp v)
          | (is, v) <- is_vs
        ]
      where
        (gtids, ws) = unzip ispace

-- | Build the statements for an in-place update nested inside a map
-- nest, as a single map kernel with one 'WriteReturns' result.
--
-- Arguments, in order:
--
-- * the surrounding 'KernelNest';
--
-- * a permutation that is applied (via 'rearrangeShape') to the pattern
--   elements of the first nesting of the 'KernelNest' when binding the
--   final kernel;
--
-- * certificates attached to the read of the value being written;
--
-- * the name of the array being updated (must be the name of one of
--   the kernel inputs computed by 'flatKernel', otherwise 'error' is
--   called);
--
-- * the slice being updated;
--
-- * the array holding the values to write.
--
-- The kernel's iteration space is that of the nest extended with one
-- fresh thread index per dimension of the slice.
segmentedUpdateKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Certs ->
  VName ->
  Slice SubExp ->
  VName ->
  DistNestT rep m (Stms rep)
segmentedUpdateKernel nest perm cs arr slice v = do
  (base_ispace, kernel_inps) <- flatKernel nest
  let slice_dims = sliceDims slice
  slice_gtids <- replicateM (length slice_dims) (newVName "gtid_slice")

  let ispace = base_ispace ++ zip slice_gtids slice_dims

  ((res_t, res), kstms) <- runBuilder $ do
    -- Compute indexes into full array.
    v' <-
      certifying cs $
        letSubExp "v" $ BasicOp $ Index v $ Slice $ map (DimFix . Var) slice_gtids
    slice_is <-
      traverse (toSubExp "index") $
        fixSlice (fmap pe64 slice) $ map (pe64 . Var) slice_gtids

    -- The write position is the thread indexes of the surrounding
    -- nest followed by the position within the slice.
    let write_is = map (Var . fst) base_ispace ++ slice_is
        arr' =
          maybe (error "incorrectly typed Update") kernelInputArray $
            find ((== arr) . kernelInputName) kernel_inps
    arr_t <- lookupType arr'
    v_t <- subExpType v'
    return
      ( v_t,
        WriteReturns mempty (arrayShape arr_t) arr' [(Slice $ map DimFix write_is, v')]
      )

  -- Remove unused kernel inputs, since some of these might
  -- reference the array we are updating.
  let kernel_inps' =
        filter ((`nameIn` (freeIn kstms <> freeIn res)) . kernelInputName) kernel_inps

  mk_lvl <- mkSegLevel
  (k, prestms) <-
    mapKernel mk_lvl ispace kernel_inps' [res_t] $
      KernelBody () kstms [res]

  traverse renameStm <=< runBuilder_ $ do
    addStms prestms

    let pat =
          Pat . rearrangeShape perm $
            patElems $ loopNestingPat $ fst nest

    letBind pat $ Op $ segOp k

-- | Produce the statements for a map kernel that reads elements of
-- @arr@ through @slice@ inside the kernel nest @nest@.
--
-- The kernel's iteration space is the flattened space of @nest@
-- extended with one fresh thread index per dimension of @slice@.  Each
-- thread reads a single element of @arr@ (the read is certified with
-- @cs@), and the kernel result is bound to the pattern of the first
-- nesting in @nest@.  All produced statements are renamed before being
-- returned.
segmentedGatherKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  Certs ->
  VName ->
  Slice SubExp ->
  DistNestT rep m (Stms rep)
segmentedGatherKernel nest cs arr slice = do
  -- One fresh thread index for every dimension the slice produces.
  let dims = sliceDims slice
  gtids <- replicateM (length dims) $ newVName "gtid_slice"

  (nest_ispace, nest_inps) <- flatKernel nest
  let full_ispace = nest_ispace ++ zip gtids dims
      -- Fixing every slice dimension to its thread index selects a
      -- single element of the sliced array.
      gtid_slice = Slice [DimFix (Var gtid) | gtid <- gtids]

  ((elem_t, elem_res), body_stms) <- runBuilder $ do
    -- Translate the per-thread index into an index into the full array.
    full_slice <-
      subExpSlice $
        sliceSlice (primExpSlice slice) (primExpSlice gtid_slice)
    elem_se <-
      certifying cs . letSubExp "v" . BasicOp $ Index arr full_slice
    elem_se_t <- subExpType elem_se
    return (elem_se_t, Returns ResultMaySimplify mempty elem_se)

  mk_lvl <- mkSegLevel
  (kernel, pre_stms) <-
    mapKernel mk_lvl full_ispace nest_inps [elem_t] $
      KernelBody () body_stms [elem_res]

  kernel_stms <- runBuilder_ $ do
    addStms pre_stms
    let nest_pat = Pat (patElems (loopNestingPat (fst nest)))
    letBind nest_pat (Op (segOp kernel))

  traverse renameStm kernel_stms

-- | Produce the statements for a histogram that occurs inside the
-- kernel nest @nest@.
--
-- The destination arrays of every histogram operator are replaced by
-- the arrays of the kernel inputs that carry the same name, after which
-- the work is handed to 'histKernel'.  The result pattern is the
-- pattern of the first nesting in @nest@ with its elements rearranged
-- by @perm@.
--
-- Calls 'error' if some destination has no matching kernel input.
segmentedHistKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Certs ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda rep ->
  [VName] ->
  DistNestT rep m (Stms rep)
segmentedHistKernel nest perm cs hist_w ops lam arrs = do
  -- Part of what 'isSegmentedOp' checks is repeated here, but a Hist
  -- is neither a reduction nor a scan, so it needs its own handling.
  (ispace, inputs) <- flatKernel nest
  let nest_pat = loopNestingPat (fst nest)
      orig_pat = Pat (rearrangeShape perm (patElems nest_pat))

      -- Every destination array _must_ be one of the kernel inputs;
      -- otherwise the original nested Hist would have been ill-typed.
      resolveDest dest =
        case find ((== dest) . kernelInputName) inputs of
          Just inp -> return (kernelInputArray inp)
          Nothing -> error "Ill-typed nested Hist encountered."

  ops' <- forM ops $ \(SOACS.HistOp num_bins rf dests nes op) -> do
    dests' <- mapM resolveDest dests
    return $ SOACS.HistOp num_bins rf dests' nes op

  mk_lvl <- asks distSegLevel
  onLambda <- asks distOnSOACSLambda
  -- Only the converted lambda is kept; statements produced while
  -- converting it are discarded.
  let onLambda' soacs_lam = fst <$> runBuilder (onLambda soacs_lam)

  liftInner . runBuilderT'_ $ do
    -- Do not launch more threads than necessary for histograms, since
    -- that may force an otherwise avoidable reduction of subhistograms.
    lvl <-
      mk_lvl
        (hist_w : map snd ispace)
        "seghist"
        (NoRecommendation SegNoVirt)
    hist_stms <-
      histKernel onLambda' lvl orig_pat ispace inputs cs hist_w ops' lam arrs
    addStms hist_stms

-- | Produce the statements of a segmented histogram.
--
-- Each SOACS-level histogram operator is first passed through
-- 'determineReduceOp', and the resulting operator lambda is converted
-- to the target representation with @onLambda@.  Kernel inputs whose
-- array is a destination of one of the operators are dropped before
-- 'segHist' is called.  The statements returned by 'segHist' are
-- renamed and added under the certificates @cs@.
histKernel ::
  (MonadBuilder m, DistRep (Rep m)) =>
  (Lambda SOACS -> m (Lambda (Rep m))) ->
  SegOpLevel (Rep m) ->
  PatT Type ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  Certs ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda (Rep m) ->
  [VName] ->
  m (Stms (Rep m))
histKernel onLambda lvl orig_pat ispace inputs cs hist_w ops lam arrs = runBuilderT'_ $ do
  ops' <- forM ops $ \(SOACS.HistOp num_bins rf dests nes op) -> do
    (red_op, red_nes, vec_shape) <- determineReduceOp op nes
    HistOp num_bins rf dests red_nes vec_shape <$> lift (onLambda red_op)

  -- Destination arrays are not passed on as ordinary kernel inputs.
  let dest_arrs = concatMap histDest ops'
      inputs' =
        [inp | inp <- inputs, kernelInputArray inp `notElem` dest_arrs]

  certifying cs $ do
    hist_stms <- segHist lvl orig_pat hist_w ispace inputs' ops' lam arrs
    renamed_stms <- traverse renameStm hist_stms
    addStms renamed_stms

-- | Try to strip vectorising maps off a reduction operator.
--
-- When every neutral element in @nes@ is a variable, the operator is
-- run through 'isVectorMap', and each neutral element is replaced by
-- the result of indexing it with a zero for each dimension of the
-- discovered shape.  The stripped lambda, the new neutral elements and
-- the shape are returned.  When some neutral element is not a variable,
-- the inputs are returned unchanged together with an empty shape.
determineReduceOp ::
  MonadBuilder m =>
  Lambda SOACS ->
  [SubExp] ->
  m (Lambda SOACS, [SubExp], Shape)
determineReduceOp lam nes
  -- FIXME? This assumes the accumulator is a replicate, and its value
  -- is extracted in a rather crude way.
  | Just ne_vs <- mapM subExpVar nes = do
      let (shape, lam') = isVectorMap lam
          -- One fixed zero index for each stripped dimension.
          zero_idxs = replicate (shapeRank shape) (DimFix (intConst Int64 0))
      nes' <- forM ne_vs $ \ne_v -> do
        ne_v_t <- lookupType ne_v
        letSubExp "hist_ne" . BasicOp . Index ne_v $
          fullSlice ne_v_t zero_idxs
      return (lam', nes', shape)
  | otherwise =
      return (lam, nes, mempty)

-- | Peel nested maps off a lambda.
--
-- A layer is peeled when the body of the lambda is exactly one
-- statement binding a 'Screma' that is a map, the body result is
-- exactly the variables bound by that statement, and the input arrays
-- of the 'Screma' are exactly the parameters of the lambda.  The width
-- of each peeled layer is collected, outermost first, into the
-- returned shape, and the innermost lambda is returned alongside it.
-- A lambda that does not have this form is returned unchanged with an
-- empty shape.
isVectorMap :: Lambda SOACS -> (Shape, Lambda SOACS)
isVectorMap lam =
  case stmsToList (bodyStms (lambdaBody lam)) of
    [Let (Pat pes) _ (Op (Screma w arrs form))]
      | map resSubExp (bodyResult (lambdaBody lam))
          == map (Var . patElemName) pes,
        Just map_lam <- isMapSOAC form,
        arrs == map paramName (lambdaParams lam) ->
        let (inner_shape, inner_lam) = isVectorMap map_lam
         in (Shape [w] <> inner_shape, inner_lam)
    _ -> (mempty, lam)

segmentedScanomapKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Certs ->
  SubExp ->
  Lambda SOACS ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  DistNestT rep m (Maybe (Stms rep))
segmentedScanomapKernel :: KernelNest
-> [Int]
-> Certs
-> SubExp
-> Lambda
-> Lambda rep
-> [SubExp]
-> [VName]
-> DistNestT rep m (Maybe (Stms rep))
segmentedScanomapKernel KernelNest
nest [Int]
perm Certs
cs SubExp
segment_size Lambda
lam Lambda rep
map_lam [SubExp]
nes [VName]
arrs = do
  [SubExp]
-> [Char]
-> ThreadRecommendation
-> BuilderT rep m (SegOpLevel rep)
mk_lvl <- (DistEnv rep m
 -> [SubExp]
 -> [Char]
 -> ThreadRecommendation
 -> BuilderT rep m (SegOpLevel rep))
-> DistNestT
     rep
     m
     ([SubExp]
      -> [Char]
      -> ThreadRecommendation
      -> BuilderT rep m (SegOpLevel rep))
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks DistEnv rep m
-> [SubExp]
-> [Char]
-> ThreadRecommendation
-> BuilderT rep m (SegOpLevel rep)
forall rep (m :: * -> *). DistEnv rep m -> MkSegLevel rep m
distSegLevel
  Lambda -> Builder rep (Lambda rep)
onLambda <- (DistEnv rep m -> Lambda -> Builder rep (Lambda rep))
-> DistNestT rep m (Lambda -> Builder rep (Lambda rep))
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks DistEnv rep m -> Lambda -> Builder rep (Lambda rep)
forall rep (m :: * -> *).
DistEnv rep m -> Lambda -> Builder rep (Lambda rep)
distOnSOACSLambda
  let onLambda' :: Lambda -> BuilderT rep m (Lambda rep)
onLambda' = ((Lambda rep, Stms rep) -> Lambda rep)
-> BuilderT rep m (Lambda rep, Stms rep)
-> BuilderT rep m (Lambda rep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Lambda rep, Stms rep) -> Lambda rep
forall a b. (a, b) -> a
fst (BuilderT rep m (Lambda rep, Stms rep)
 -> BuilderT rep m (Lambda rep))
-> (Lambda -> BuilderT rep m (Lambda rep, Stms rep))
-> Lambda
-> BuilderT rep m (Lambda rep)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Builder rep (Lambda rep) -> BuilderT rep m (Lambda rep, Stms rep)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder (Builder rep (Lambda rep) -> BuilderT rep m (Lambda rep, Stms rep))
-> (Lambda -> Builder rep (Lambda rep))
-> Lambda
-> BuilderT rep m (Lambda rep, Stms rep)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda -> Builder rep (Lambda rep)
onLambda
  KernelNest
-> [Int]
-> Names
-> Names
-> [SubExp]
-> [VName]
-> (PatT Type
    -> [(VName, SubExp)]
    -> [KernelInput]
    -> [SubExp]
    -> [VName]
    -> BuilderT rep m ())
-> DistNestT rep m (Maybe (Stms rep))
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
KernelNest
-> [Int]
-> Names
-> Names
-> [SubExp]
-> [VName]
-> (PatT Type
    -> [(VName, SubExp)]
    -> [KernelInput]
    -> [SubExp]
    -> [VName]
    -> BuilderT rep m ())
-> DistNestT rep m (Maybe (Stms rep))
isSegmentedOp KernelNest
nest [Int]
perm (Lambda -> Names
forall a. FreeIn a => a -> Names
freeIn Lambda
lam) (Lambda rep -> Names
forall a. FreeIn a => a -> Names
freeIn Lambda rep
map_lam) [SubExp]
nes [] ((PatT Type
  -> [(VName, SubExp)]
  -> [KernelInput]
  -> [SubExp]
  -> [VName]
  -> BuilderT rep m ())
 -> DistNestT rep m (Maybe (Stms rep)))
-> (PatT Type
    -> [(VName, SubExp)]
    -> [KernelInput]
    -> [SubExp]
    -> [VName]
    -> BuilderT rep m ())
-> DistNestT rep m (Maybe (Stms rep))
forall a b. (a -> b) -> a -> b
$
    \PatT Type
pat [(VName, SubExp)]
ispace [KernelInput]
inps [SubExp]
nes' [VName]
_ -> do
      (Lambda
lam', [SubExp]
nes'', Shape
shape) <- Lambda -> [SubExp] -> BuilderT rep m (Lambda, [SubExp], Shape)
forall (m :: * -> *).
MonadBuilder m =>
Lambda -> [SubExp] -> m (Lambda, [SubExp], Shape)
determineReduceOp Lambda
lam [SubExp]
nes'
      Lambda rep
lam'' <- Lambda -> BuilderT rep m (Lambda rep)
onLambda' Lambda
lam'
      let scan_op :: SegBinOp rep
scan_op = Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
Noncommutative Lambda rep
lam'' [SubExp]
nes'' Shape
shape
      SegOpLevel rep
lvl <- [SubExp]
-> [Char]
-> ThreadRecommendation
-> BuilderT rep m (SegOpLevel rep)
mk_lvl (SubExp
segment_size SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: ((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(VName, SubExp)]
ispace) [Char]
"segscan" (ThreadRecommendation -> BuilderT rep m (SegOpLevel rep))
-> ThreadRecommendation -> BuilderT rep m (SegOpLevel rep)
forall a b. (a -> b) -> a -> b
$ SegVirt -> ThreadRecommendation
NoRecommendation SegVirt
SegNoVirt
      Stms rep -> BuilderT rep m ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms rep -> BuilderT rep m ())
-> BuilderT rep m (Stms rep) -> BuilderT rep m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stm rep -> BuilderT rep m (Stm rep))
-> Stms rep -> BuilderT rep m (Stms rep)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse Stm rep -> BuilderT rep m (Stm rep)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Stm rep -> m (Stm rep)
renameStm
        (Stms rep -> BuilderT rep m (Stms rep))
-> BuilderT rep m (Stms rep) -> BuilderT rep m (Stms rep)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel rep
-> Pat rep
-> Certs
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BuilderT rep m (Stms rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat rep
-> Certs
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms rep)
segScan SegOpLevel rep
lvl PatT Type
Pat rep
pat Certs
cs SubExp
segment_size [SegBinOp rep
scan_op] Lambda rep
map_lam [VName]
arrs [(VName, SubExp)]
ispace [KernelInput]
inps

-- | Try to express a reduction (with a fused map function) that sits
-- inside a kernel nest as a single segmented reduction built with
-- 'segRed'.
--
-- The feasibility checks are delegated to 'isSegmentedOp'; when they
-- fail the result is 'Nothing'.  On success the statements produced by
-- 'segRed' are renamed with 'renameStm' and returned.
--
-- Arguments, in order: the kernel nest, the permutation applied to the
-- outermost pattern, certificates passed on to 'segRed', the segment
-- size, the commutativity of the operator, the reduction operator, the
-- map function, the neutral elements, and the input arrays.
regularSegmentedRedomapKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Certs ->
  SubExp ->
  Commutativity ->
  Lambda rep ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  DistNestT rep m (Maybe (Stms rep))
regularSegmentedRedomapKernel nest perm cs segment_size comm lam map_lam nes arrs = do
  mk_lvl <- asks distSegLevel
  -- No arrays are handed to 'isSegmentedOp' (the list is empty), so the
  -- last argument of the callback is ignored and the original @arrs@
  -- are passed to 'segRed' unchanged.
  isSegmentedOp nest perm (freeIn lam) (freeIn map_lam) nes [] $
    \pat ispace inps nes' _ -> do
      let red_op = SegBinOp comm lam nes' mempty
      -- The level is computed for the segment size followed by the
      -- sizes of the flattened iteration space.
      lvl <-
        mk_lvl (segment_size : map snd ispace) "segred" $
          NoRecommendation SegNoVirt
      addStms
        =<< traverse renameStm
        =<< segRed lvl pat cs segment_size [red_op] map_lam arrs ispace inps

-- | Shared feasibility check and driver for turning an operation
-- nested in a kernel nest into a segmented operation.
--
-- Arguments, in order: the kernel nest, the permutation applied to the
-- pattern of the outermost nesting, the names free in the operator,
-- the names free in the fold function (currently unused), the neutral
-- elements, the input arrays, and a callback that builds the segmented
-- operation.
--
-- The callback receives the (permuted) pattern of the outermost
-- nesting, the flattened iteration space, the kernel inputs, the
-- neutral elements and the prepared input arrays.  The statements it
-- produces are returned in a 'Just'.  If any of the checks below
-- fails, the result is 'Nothing' and the callback is never run.
isSegmentedOp ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Names ->
  Names ->
  [SubExp] ->
  [VName] ->
  ( PatT Type ->
    [(VName, SubExp)] ->
    [KernelInput] ->
    [SubExp] ->
    [VName] ->
    BuilderT rep m ()
  ) ->
  DistNestT rep m (Maybe (Stms rep))
isSegmentedOp nest perm free_in_op _free_in_fold_op nes arrs m = runMaybeT $ do
  -- We must verify that array inputs to the operation are inputs to
  -- the outermost loop nesting or free in the loop nest.  Nothing
  -- free in the op may be bound by the nest.  Furthermore, the
  -- neutral elements must be free in the loop nest.
  --
  -- We must summarise any names from free_in_op that are bound in the
  -- nest, and describe how to obtain them given segment indices.

  let bound_by_nest = boundInKernelNest nest

  (ispace, kernel_inps) <- flatKernel nest

  -- 'fail' in 'MaybeT' aborts the whole check with 'Nothing'; the
  -- message itself is discarded.
  when (free_in_op `namesIntersect` bound_by_nest) $
    fail "Non-fold lambda uses nest-bound parameters."

  let indices = map fst ispace

      -- A neutral element is accepted only if it is not a variable
      -- bound by the nest.
      prepareNe (Var v)
        | v `nameIn` bound_by_nest =
          fail "Neutral element bound in nest"
      prepareNe ne = pure ne

      -- Validate one input array now, and return a builder action that
      -- yields the array to use once we are inside the builder.
      prepareArr arr =
        case find ((== arr) . kernelInputName) kernel_inps of
          Just inp
            -- The array is a kernel input indexed by exactly the
            -- iteration-space indices: use the underlying array.
            | kernelInputIndices inp == map Var indices ->
              pure (pure (kernelInputArray inp))
          Nothing
            | not (arr `nameIn` bound_by_nest) ->
              -- This input is something that is free inside
              -- the loop nesting. We will have to replicate
              -- it.
              pure $
                letExp
                  (baseString arr ++ "_repd")
                  (BasicOp $ Replicate (Shape $ map snd ispace) $ Var arr)
          _ ->
            fail "Input not free, perfectly mapped, or outermost."

  nes' <- mapM prepareNe nes

  mk_arrs <- mapM prepareArr arrs

  lift $
    liftInner $
      runBuilderT'_ $ do
        nested_arrs <- sequence mk_arrs

        -- The pattern handed to the callback is that of the outermost
        -- nesting, with its elements rearranged according to @perm@.
        let pat =
              Pat . rearrangeShape perm $
                patElems $ loopNestingPat $ fst nest

        m pat ispace kernel_inps nes' nested_arrs

-- | Relate the elements of a pattern to a result.
--
-- Pattern elements whose names are not in @'freeIn' res@ are
-- considered unused.  The result subexpressions are extended with the
-- names of those unused elements, and the pattern names (as 'Var's)
-- are tested against that extended list with 'isPermutationOf'.
--
-- Returns the permutation found together with the unused pattern
-- elements, or 'Nothing' if 'isPermutationOf' finds no permutation.
permutationAndMissing :: PatT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing (Pat pes) res = do
  let (_used, unused) =
        partition ((`nameIn` freeIn res) . patElemName) pes
      res' = map resSubExp res
      -- Unused pattern elements are appended so that both lists can
      -- cover the same names.
      res_expanded = res' ++ map (Var . patElemName) unused
  perm <- map (Var . patElemName) pes `isPermutationOf` res_expanded
  pure (perm, unused)

-- | Add extra pattern elements to every kernel nesting level.
--
-- For each nesting, every element of @pes@ is copied under a fresh
-- name (with the same base string) and appended to that nesting's
-- pattern.  The type of each copy is the type of the original element
-- wrapped in the widths of that nesting and of all nestings inside it.
expandKernelNest ::
  MonadFreshNames m =>
  [PatElemT Type] ->
  KernelNest ->
  m KernelNest
expandKernelNest pes (outer_nest, inner_nests) = do
  let inner_widths = map loopNestingWidth inner_nests
      outer_size = loopNestingWidth outer_nest : inner_widths
      -- 'tails' yields one more list than there are inner nestings;
      -- 'zipWithM' below ignores the trailing empty one, so each inner
      -- nesting is paired with its own width and those inside it.
      inner_sizes = tails inner_widths
  outer_nest' <- expandWith outer_nest outer_size
  inner_nests' <- zipWithM expandWith inner_nests inner_sizes
  pure (outer_nest', inner_nests')
  where
    -- Append freshly named copies of @pes@, shaped by @dims@, to the
    -- pattern of one nesting.
    expandWith nest dims = do
      pes' <- mapM (expandPatElemWith dims) pes
      pure
        nest
          { loopNestingPat =
              Pat $ patElems (loopNestingPat nest) <> pes'
          }

    -- Copy a pattern element under a fresh name, turning its type into
    -- an array of shape @dims@ of the original type.
    expandPatElemWith dims pe = do
      name <- newVName $ baseString $ patElemName pe
      pure
        pe
          { patElemName = name,
            patElemDec = patElemType pe `arrayOfShape` Shape dims
          }

-- | Finish an attempt at turning a statement into kernel code.
--
-- Arguments, in order: certificates to attach, the original statement,
-- the accumulator to fall back to, the postponed statements to emit on
-- success, the accumulator to continue with on success, and the
-- outcome of the attempt.
--
-- If the attempt produced nothing ('Nothing'), the original statement
-- is certified and added to the fallback accumulator with
-- 'addStmToAcc'.  Otherwise the postponed statements are emitted with
-- 'addPostStms', the produced statements are certified and emitted
-- with 'postStm', and the success accumulator is returned.
kernelOrNot ::
  (MonadFreshNames m, DistRep rep) =>
  Certs ->
  Stm SOACS ->
  DistAcc rep ->
  PostStms rep ->
  DistAcc rep ->
  Maybe (Stms rep) ->
  DistNestT rep m (DistAcc rep)
kernelOrNot cs stm acc _ _ Nothing =
  addStmToAcc (certify cs stm) acc
kernelOrNot cs _ _ kernels acc' (Just stms) = do
  addPostStms kernels
  postStm $ fmap (certify cs) stms
  pure acc'

-- | Distribute the body of a map loop.
--
-- A new accumulator is set up whose innermost target is the pattern of
-- the map paired with the result of its lambda body, and whose
-- statement list is empty.  The statements of the lambda body are then
-- processed with 'distributeMapBodyStms' followed by 'distribute',
-- all inside the nesting established by 'mapNesting'.  Finally
-- 'distribute' is applied once more to the accumulator that
-- 'mapNesting' returns.
distributeMap ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  MapLoop ->
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
distributeMap (MapLoop pat aux w lam arrs) acc =
  distribute
    =<< mapNesting
      pat
      aux
      w
      lam
      arrs
      (distribute =<< distributeMapBodyStms acc' lam_stms)
  where
    -- Accumulator used while inside the map: the targets of the
    -- incoming accumulator extended with this map's pattern and body
    -- result, and no statements yet.
    acc' =
      DistAcc
        { distTargets =
            pushInnerTarget
              (pat, bodyResult $ lambdaBody lam)
              $ distTargets acc,
          distStms = mempty
        }

    -- The statements making up the body of the map function.
    lam_stms = bodyStms $ lambdaBody lam