{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
module Futhark.Pass.ExtractKernels.DistributeNests
( MapLoop(..)
, mapLoopStm
, bodyContainsParallelism
, lambdaContainsParallelism
, determineReduceOp
, incrementalFlattening
, histKernel
, DistEnv (..)
, DistAcc (..)
, runDistNestT
, DistNestT
, distributeMap
, distribute
, distributeSingleStm
, distributeMapBodyStms
, addStmsToAcc
, addStmToAcc
, permutationAndMissing
, addPostStms
, postStm
, inNesting
)
where
import Control.Arrow (first)
import Control.Monad.Identity
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.Writer.Strict
import Control.Monad.Trans.Maybe
import Data.Maybe
import Data.List (find, partition, tails)
import Futhark.Representation.AST
import qualified Futhark.Representation.SOACS as SOACS
import Futhark.Representation.SOACS.SOAC hiding (HistOp, histDest)
import Futhark.Representation.SOACS (SOACS)
import Futhark.Representation.SOACS.Simplify (simpleSOACS)
import Futhark.Representation.SegOp
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.CopyPropagate
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.Interchange
import Futhark.Util
import Futhark.Util.Log
-- | Reinterpret a scope of some lore as a scope of the 'SOACS' lore.
-- This is just 'castScope' with the target lore fixed to 'SOACS'.
scopeForSOACs :: SameScope lore SOACS => Scope lore -> Scope SOACS
scopeForSOACs = castScope
-- | The pieces of a map written as a SOACS statement: the pattern it
-- binds, its certificates, the size argument handed to 'Screma', the
-- mapped lambda, and the input arrays (see 'mapLoopStm', which
-- reassembles these into a statement).
data MapLoop = MapLoop SOACS.Pattern Certificates SubExp SOACS.Lambda [VName]
-- | Rebuild the SOACS statement described by a 'MapLoop': a 'Screma'
-- whose form is @'mapSOAC' lam@, bound to the stored pattern and
-- carrying the stored certificates.
mapLoopStm :: MapLoop -> Stm SOACS
mapLoopStm (MapLoop pat cs w lam arrs) =
  Let pat (StmAux cs ()) $ Op $ Screma w (mapSOAC lam) arrs
-- | The read-only environment of 'DistNestT'.
data DistEnv lore m =
  DistEnv { distNest :: Nestings
            -- ^ The nestings we are currently inside.
          , distScope :: Scope lore
            -- ^ The scope reported by 'askScope' and extended by
            -- 'localScope'.
          , distOnTopLevelStms :: Stms SOACS -> DistNestT lore m (Stms lore)
            -- ^ Callback translating a sequence of SOACS statements
            -- into the target lore.  NOTE(review): presumably used for
            -- statements outside any map nest; confirm against callers.
          , distOnInnerMap :: MapLoop -> DistAcc lore
                           -> DistNestT lore m (DistAcc lore)
            -- ^ Callback invoked with a 'MapLoop' and the current
            -- accumulator, producing the updated accumulator.
          , distOnSOACSStms :: Stm SOACS -> Binder lore (Stms lore)
            -- ^ Callback translating a single SOACS statement into
            -- target-lore statements (used by 'addStmToAcc').
          , distOnSOACSLambda :: Lambda SOACS -> Binder lore (Lambda lore)
            -- ^ Callback translating a SOACS lambda into a
            -- target-lore lambda (used by 'soacsLambda').
          , distSegLevel :: MkSegLevel lore m
            -- ^ How to construct the segment level for generated
            -- segmented operations.
          }
-- | The accumulator threaded through distribution.
data DistAcc lore =
  DistAcc { distTargets :: Targets
            -- ^ The targets still to be produced.
          , distStms :: Stms lore
            -- ^ Statements accumulated so far; new statements are
            -- prepended (see 'addStmsToAcc' and 'addStmToAcc').
          }
-- | The write-only output of 'DistNestT'.
data DistRes lore =
  DistRes { accPostStms :: PostStms lore
            -- ^ Statements emitted via 'addPostStms' / 'postStm'.
          , accLog :: Log
            -- ^ Log messages emitted via 'addLog'.
          }
-- | Combines the two components pointwise.
instance Semigroup (DistRes lore) where
  DistRes ks1 log1 <> DistRes ks2 log2 =
    DistRes (ks1 <> ks2) (log1 <> log2)
-- | The empty result: no post-statements and an empty log.
instance Monoid (DistRes lore) where
  mempty = DistRes mempty mempty
-- | A wrapper around 'Stms' whose 'Semigroup' instance appends in the
-- opposite order to the underlying 'Stms'.
newtype PostStms lore = PostStms { unPostStms :: Stms lore }
-- | Note the flipped operands: the right-hand statements come first in
-- the combined sequence, so statements told later end up before
-- statements told earlier.
instance Semigroup (PostStms lore) where
  PostStms xs <> PostStms ys = PostStms $ ys <> xs
-- | The empty sequence of post-statements.
instance Monoid (PostStms lore) where
  mempty = PostStms mempty
-- | The scope introduced by the pattern of the outermost target of the
-- accumulator.
typeEnvFromDistAcc :: DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc = scopeOfPattern . fst . outerTarget . distTargets
-- | Prepend already-translated statements to the statements held in the
-- accumulator.
addStmsToAcc :: Stms lore -> DistAcc lore -> DistAcc lore
addStmsToAcc stms acc =
  acc { distStms = stms <> distStms acc }
-- | Translate a single SOACS statement with the 'distOnSOACSStms'
-- callback from the environment and prepend the resulting statements
-- to the accumulator.
--
-- Only the value returned by the callback is kept; any statements the
-- callback emitted into the 'Binder' itself are discarded.
addStmToAcc :: (MonadFreshNames m, DistLore lore) =>
               Stm SOACS -> DistAcc lore
            -> DistNestT lore m (DistAcc lore)
addStmToAcc stm acc = do
  onSoacs <- asks distOnSOACSStms
  (stm', _) <- runBinder $ onSoacs stm
  return acc { distStms = stm' <> distStms acc }
-- | Translate a SOACS lambda into the target lore with the
-- 'distOnSOACSLambda' callback from the environment.  Statements the
-- callback emitted into the 'Binder' are discarded; only the resulting
-- lambda is returned.
soacsLambda :: (MonadFreshNames m, DistLore lore) =>
               Lambda SOACS -> DistNestT lore m (Lambda lore)
soacsLambda lam = do
  onLambda <- asks distOnSOACSLambda
  fst <$> runBinder (onLambda lam)
-- | The monad transformer in which distribution is carried out: a
-- reader of 'DistEnv' layered over a writer of 'DistRes' on top of the
-- base monad @m@.  All instances below are obtained by newtype
-- deriving from the underlying transformer stack.
newtype DistNestT lore m a =
  DistNestT (ReaderT (DistEnv lore m) (WriterT (DistRes lore) m) a)
  deriving (Functor, Applicative, Monad,
            MonadReader (DistEnv lore m),
            MonadWriter (DistRes lore))
-- | Lifts an action of the base monad through both the writer and the
-- reader layer.
instance MonadTrans (DistNestT lore) where
  lift = DistNestT . lift . lift
-- | The name source is delegated to the monad beneath the reader
-- layer, and so ultimately to the base monad @m@.
instance MonadFreshNames m => MonadFreshNames (DistNestT lore m) where
  getNameSource = DistNestT $ lift getNameSource
  putNameSource = DistNestT . lift . putNameSource
-- | The current scope is the 'distScope' field of the environment.
instance (Monad m, Attributes lore) => HasScope lore (DistNestT lore m) where
  askScope = asks distScope
-- | Runs the action with the given scope combined (on the left) with
-- the existing 'distScope'.
instance (Monad m, Attributes lore) => LocalScope lore (DistNestT lore m) where
  localScope types = local $ \env ->
    env { distScope = types <> distScope env }
-- | Log messages are collected in the 'accLog' component of the writer
-- output; 'runDistNestT' later forwards them to the base monad.
instance Monad m => MonadLogger (DistNestT lore m) where
  addLog msgs = tell mempty { accLog = msgs }
-- | Run a distribution computation in the given environment.
--
-- The collected log is forwarded to the base monad, and the result is
-- the collected post-statements followed by one extra statement for
-- each remaining value element of the outermost target in the final
-- accumulator.
runDistNestT :: (MonadLogger m, DistLore lore) =>
                DistEnv lore m -> DistNestT lore m (DistAcc lore) -> m (Stms lore)
runDistNestT env (DistNestT m) = do
  (acc, res) <- runWriterT $ runReaderT m env
  addLog $ accLog res
  -- Whatever is still left in the outermost target has not been
  -- produced by any statement so far, so bind it explicitly here.
  return $
    unPostStms (accPostStms res) <>
    identityStms (outerTarget $ distTargets acc)
  where -- The loop of the outermost nesting: the first of the outer
        -- nestings when there are any, otherwise the sole nesting.
        outermost = nestingLoop $
                    case distNest env of (nest, []) -> nest
                                         (_, nest : _) -> nest

        -- Maps each parameter name of the outermost loop to the array
        -- it is paired with.
        params_to_arrs = map (first paramName) $
                         loopNestingParamsAndArrs outermost

        identityStms (rem_pat, res) =
          stmsFromList $ zipWith identityStm (patternValueElements rem_pat) res

        -- A result that is a parameter of the outermost loop is bound
        -- to a copy of the corresponding array...
        identityStm pe (Var v)
          | Just arr <- lookup v params_to_arrs =
              Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ Copy arr
        -- ...and any other result is replicated to the width of the
        -- outermost loop.
        identityStm pe se =
          Let (Pattern [] [pe]) (defAux ()) $ BasicOp $
          Replicate (Shape [loopNestingWidth outermost]) se
-- | Emit post-statements into the writer output, leaving the log
-- component empty.
addPostStms :: Monad m => PostStms lore -> DistNestT lore m ()
addPostStms ks = tell $ mempty { accPostStms = ks }
-- | Emit the given statements as post-statements; a convenience
-- wrapper around 'addPostStms'.
postStm :: Monad m => Stms lore -> DistNestT lore m ()
postStm stms = addPostStms $ PostStms stms
-- | Run an action in an environment that accounts for the given SOACS
-- statement: the scope is extended with the names the statement
-- brings into scope (cast to the target lore), and the names bound by
-- its pattern are passed to 'letBindInInnerNesting' for the current
-- nestings.
withStm :: (Monad m, DistLore lore) =>
           Stm SOACS -> DistNestT lore m a -> DistNestT lore m a
withStm stm = local $ \env ->
  env { distScope =
          castScope (scopeOf stm) <> distScope env
      , distNest =
          letBindInInnerNesting provided $
          distNest env
      }
  where provided = namesFromList $ patternNames $ stmPattern stm
-- | Run an action inside one additional map nesting.
--
-- A new 'MapNesting', built from the pattern, certificates, width and
-- the lambda parameters zipped with the input arrays, is pushed as the
-- innermost nesting, and the scope is extended with the scope of the
-- lambda (cast to the target lore).
mapNesting :: (Monad m, DistLore lore) =>
              PatternT Type -> Certificates -> SubExp -> Lambda SOACS -> [VName]
           -> DistNestT lore m a
           -> DistNestT lore m a
mapNesting pat cs w lam arrs = local $ \env ->
  env { distNest = pushInnerNesting nest $ distNest env
      , distScope = castScope (scopeOf lam) <> distScope env
      }
  where nest = Nesting mempty $
               MapNesting pat cs w $
               zip (lambdaParams lam) arrs
-- | Run an action as if inside the given kernel nest.
--
-- The current nestings are replaced (not extended) by ones built from
-- the kernel nest: the last loop of the nest becomes the innermost
-- nesting and the remaining loops, outermost first, become the outer
-- nestings.  Each 'Nesting' is created with an empty name set.  The
-- scope is extended with the scopes of all loops in the nest.
inNesting :: (Monad m, DistLore lore) =>
             KernelNest -> DistNestT lore m a -> DistNestT lore m a
inNesting (outer, nests) = local $ \env ->
  env { distNest = (inner, nests')
      , distScope = foldMap scopeOfLoopNesting (outer : nests) <> distScope env
      }
  where (inner, nests') =
          case reverse nests of
            []            -> (asNesting outer, [])
            (inner' : ns) -> (asNesting inner', map asNesting $ outer : reverse ns)
        asNesting = Nesting mempty
-- | Does any statement directly in the body have an 'Op' expression?
-- Only the body's own statements are inspected; nested bodies are not
-- searched.
bodyContainsParallelism :: Body SOACS -> Bool
bodyContainsParallelism = any (isMap . stmExp) . bodyStms
  where isMap Op{} = True
        isMap _ = False
-- | 'bodyContainsParallelism' applied to the body of the lambda.
lambdaContainsParallelism :: Lambda SOACS -> Bool
lambdaContainsParallelism = bodyContainsParallelism . lambdaBody
-- | 'True' exactly when 'unixEnvironment' contains an entry whose key
-- is @FUTHARK_INCREMENTAL_FLATTENING@.  Only the presence of the key
-- matters; the associated value is never inspected.
incrementalFlattening :: Bool
incrementalFlattening =
  "FUTHARK_INCREMENTAL_FLATTENING" `elem` map fst unixEnvironment
-- | Drop the innermost target from the accumulator's 'distTargets'.
--
-- The popped target itself is discarded; only the remaining targets
-- are kept.  Calls 'error' if 'popInnerTarget' yields 'Nothing'.
leavingNesting :: MonadFreshNames m =>
                  DistAcc lore -> DistNestT lore m (DistAcc lore)
leavingNesting acc
  | Just (_, remaining) <- popInnerTarget (distTargets acc) =
      return acc { distTargets = remaining }
  | otherwise =
      error "The kernel targets list is unexpectedly small"
-- | Feed the statements of a map body through 'maybeDistributeStm',
-- starting from the given accumulator, and finish by calling
-- 'distribute' on the accumulator that results.
--
-- Statements are handled back to front: for a list @stm : rest@, the
-- @rest@ is processed first and its accumulator is then handed to
-- 'maybeDistributeStm' together with @stm@.
distributeMapBodyStms :: (MonadFreshNames m, DistLore lore) => DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
distributeMapBodyStms orig_acc body_stms =
  distribute =<< go (stmsToList body_stms)
  where
    -- Nothing left: the starting accumulator is the result.
    go [] = return orig_acc

    -- A 'Stream' in 'Sequential' form is not handed to
    -- 'maybeDistributeStm'.  It is replaced by the statements that
    -- 'sequentialStreamWholeArray' produces for it, passed through
    -- 'copyPropagateInStms', tagged with the certificates of the
    -- original statement, and then processed in its place.
    go (Let pat (StmAux cs _) (Op (Stream w (Sequential accs) lam arrs)) : rest) = do
      scope <- asksScope scopeForSOACs
      (_, seq_stms) <-
        runBinderT (sequentialStreamWholeArray pat w accs lam arrs) scope
      (_, propagated) <-
        runReaderT (copyPropagateInStms simpleSOACS scope seq_stms) scope
      go $ map (certify cs) (stmsToList propagated) ++ rest

    -- General case.  'withStm' wraps both the processing of the
    -- remaining statements and the handling of @stm@ itself
    -- (presumably to make @stm@'s bindings visible -- confirm against
    -- the definition of 'withStm').
    go (stm : rest) =
      withStm stm $ go rest >>= maybeDistributeStm stm
-- | Fetch the 'distOnInnerMap' callback from the environment and apply
-- it to the given map loop and accumulator.
onInnerMap :: Monad m => MapLoop -> DistAcc lore -> DistNestT lore m (DistAcc lore)
onInnerMap loop acc =
  asks distOnInnerMap >>= \handler -> handler loop acc
-- | Run the 'distOnTopLevelStms' callback from the environment on the
-- given SOACS statements and pass the statements it returns to
-- 'postStm'.
onTopLevelStms :: Monad m => Stms SOACS -> DistNestT lore m ()
onTopLevelStms stms = do
  handler <- asks distOnTopLevelStms
  converted <- handler stms
  postStm converted
maybeDistributeStm :: (MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore
-> DistNestT lore m (DistAcc lore)
maybeDistributeStm :: Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat StmAux (ExpAttr SOACS)
_ (Op (Screma w form arrs))) DistAcc lore
acc
| Just Lambda
lam <- ScremaForm SOACS -> Maybe Lambda
forall lore. ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC ScremaForm SOACS
form =
DistAcc lore -> DistNestT lore m (Maybe (DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore -> DistNestT lore m (Maybe (DistAcc lore))
distributeIfPossible DistAcc lore
acc DistNestT lore m (Maybe (DistAcc lore))
-> (Maybe (DistAcc lore) -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Maybe (DistAcc lore)
Nothing -> Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc
Just DistAcc lore
acc' -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore -> DistNestT lore m (DistAcc lore)
distribute (DistAcc lore -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< MapLoop -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
Monad m =>
MapLoop -> DistAcc lore -> DistNestT lore m (DistAcc lore)
onInnerMap (Pattern -> Certificates -> SubExp -> Lambda -> [VName] -> MapLoop
MapLoop Pattern
pat (Stm SOACS -> Certificates
forall lore. Stm lore -> Certificates
stmCerts Stm SOACS
bnd) SubExp
w Lambda
lam [VName]
arrs) DistAcc lore
acc'
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat StmAux (ExpAttr SOACS)
_ (DoLoop [] [(FParam SOACS, SubExp)]
val form :: LoopForm SOACS
form@ForLoop{} Body SOACS
body)) DistAcc lore
acc
| [PatElemT Type] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (PatternT Type -> [PatElemT Type]
forall attr. PatternT attr -> [PatElemT attr]
patternContextElements PatternT Type
Pattern
pat), Body SOACS -> Bool
bodyContainsParallelism Body SOACS
body =
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ LoopForm SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm SOACS
form Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
PostStms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
PostStms lore -> DistNestT lore m ()
addPostStms PostStms lore
kernels
KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
Scope SOACS
types <- (Scope lore -> Scope SOACS) -> DistNestT lore m (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope lore -> Scope SOACS
forall lore. SameScope lore SOACS => Scope lore -> Scope SOACS
scopeForSOACs
Stms SOACS
bnds <- ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
-> Scope SOACS -> DistNestT lore m (Stms SOACS)
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT
(KernelNest
-> SeqLoop -> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
forall (m :: * -> *).
(MonadFreshNames m, HasScope SOACS m) =>
KernelNest -> SeqLoop -> m (Stms SOACS)
interchangeLoops KernelNest
nest' ([Int]
-> Pattern
-> [(FParam SOACS, SubExp)]
-> LoopForm SOACS
-> Body SOACS
-> SeqLoop
SeqLoop [Int]
perm Pattern
pat [(FParam SOACS, SubExp)]
val LoopForm SOACS
form Body SOACS
body)) Scope SOACS
types
Stms SOACS -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
Stms SOACS -> DistNestT lore m ()
onTopLevelStms Stms SOACS
bnds
DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc lore
acc'
Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pattern
pat StmAux (ExpAttr SOACS)
_ (If SubExp
cond Body SOACS
tbranch Body SOACS
fbranch IfAttr (BranchType SOACS)
ret)) DistAcc lore
acc
| [PatElemT Type] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (PatternT Type -> [PatElemT Type]
forall attr. PatternT attr -> [PatElemT attr]
patternContextElements PatternT Type
Pattern
pat),
Body SOACS -> Bool
bodyContainsParallelism Body SOACS
tbranch Bool -> Bool -> Bool
|| Body SOACS -> Bool
bodyContainsParallelism Body SOACS
fbranch Bool -> Bool -> Bool
||
Bool -> Bool
not ((TypeBase ExtShape NoUniqueness -> Bool)
-> [TypeBase ExtShape NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase ExtShape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (IfAttr (TypeBase ExtShape NoUniqueness)
-> [TypeBase ExtShape NoUniqueness]
forall rt. IfAttr rt -> [rt]
ifReturns IfAttr (TypeBase ExtShape NoUniqueness)
IfAttr (BranchType SOACS)
ret)) =
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
stm DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$
(SubExp -> Names
forall a. FreeIn a => a -> Names
freeIn SubExp
cond Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> IfAttr (TypeBase ExtShape NoUniqueness) -> Names
forall a. FreeIn a => a -> Names
freeIn IfAttr (TypeBase ExtShape NoUniqueness)
IfAttr (BranchType SOACS)
ret) Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
PostStms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
PostStms lore -> DistNestT lore m ()
addPostStms PostStms lore
kernels
Scope SOACS
types <- (Scope lore -> Scope SOACS) -> DistNestT lore m (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope lore -> Scope SOACS
forall lore. SameScope lore SOACS => Scope lore -> Scope SOACS
scopeForSOACs
let branch :: Branch
branch = [Int]
-> Pattern
-> SubExp
-> Body SOACS
-> Body SOACS
-> IfAttr (BranchType SOACS)
-> Branch
Branch [Int]
perm Pattern
pat SubExp
cond Body SOACS
tbranch Body SOACS
fbranch IfAttr (BranchType SOACS)
ret
Stms SOACS
stms <- ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
-> Scope SOACS -> DistNestT lore m (Stms SOACS)
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT (KernelNest
-> Branch -> ReaderT (Scope SOACS) (DistNestT lore m) (Stms SOACS)
forall (m :: * -> *).
(MonadFreshNames m, HasScope SOACS m) =>
KernelNest -> Branch -> m (Stms SOACS)
interchangeBranch KernelNest
nest' Branch
branch) Scope SOACS
types
Stms SOACS -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
Stms SOACS -> DistNestT lore m ()
onTopLevelStms Stms SOACS
stms
DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc lore
acc'
Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
stm DistAcc lore
acc
maybeDistributeStm (Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Screma w form arrs))) DistAcc lore
acc
| Just [Reduce Commutativity
comm Lambda
lam Result
nes] <- ScremaForm SOACS -> Maybe [Reduce SOACS]
forall lore. ScremaForm lore -> Maybe [Reduce lore]
isReduceSOAC ScremaForm SOACS
form,
Just BinderT SOACS (DistNestT lore m) ()
m <- Pattern
-> SubExp
-> Commutativity
-> Lambda
-> [(SubExp, VName)]
-> Maybe (BinderT SOACS (DistNestT lore m) ())
forall (m :: * -> *).
(MonadBinder m, Lore m ~ SOACS) =>
Pattern
-> SubExp
-> Commutativity
-> Lambda
-> [(SubExp, VName)]
-> Maybe (m ())
irwim Pattern
pat SubExp
w Commutativity
comm Lambda
lam ([(SubExp, VName)] -> Maybe (BinderT SOACS (DistNestT lore m) ()))
-> [(SubExp, VName)] -> Maybe (BinderT SOACS (DistNestT lore m) ())
forall a b. (a -> b) -> a -> b
$ Result -> [VName] -> [(SubExp, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip Result
nes [VName]
arrs = do
Scope SOACS
types <- (Scope lore -> Scope SOACS) -> DistNestT lore m (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope lore -> Scope SOACS
forall lore. SameScope lore SOACS => Scope lore -> Scope SOACS
scopeForSOACs
(()
_, Stms SOACS
bnds) <- BinderT SOACS (DistNestT lore m) ()
-> Scope SOACS -> DistNestT lore m ((), Stms SOACS)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT (Certificates
-> BinderT SOACS (DistNestT lore m) ()
-> BinderT SOACS (DistNestT lore m) ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs BinderT SOACS (DistNestT lore m) ()
m) Scope SOACS
types
DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
distributeMapBodyStms DistAcc lore
acc Stms SOACS
bnds
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Scatter w lam ivs as))) DistAcc lore
acc =
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
| Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
Lambda lore
lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
lam
PostStms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
PostStms lore -> DistNestT lore m ()
addPostStms PostStms lore
kernels
Stms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
Stms lore -> DistNestT lore m ()
postStm (Stms lore -> DistNestT lore m ())
-> DistNestT lore m (Stms lore) -> DistNestT lore m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< KernelNest
-> [Int]
-> PatternT Type
-> Certificates
-> SubExp
-> Lambda lore
-> [VName]
-> [(SubExp, Int, VName)]
-> DistNestT lore m (Stms lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
KernelNest
-> [Int]
-> PatternT Type
-> Certificates
-> SubExp
-> Lambda lore
-> [VName]
-> [(SubExp, Int, VName)]
-> DistNestT lore m (Stms lore)
segmentedScatterKernel KernelNest
nest' [Int]
perm PatternT Type
Pattern
pat Certificates
cs SubExp
w Lambda lore
lam' [VName]
ivs [(SubExp, Int, VName)]
as
DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc lore
acc'
Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Hist w ops lam as))) DistAcc lore
acc =
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
| Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
Lambda lore
lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
lam
KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
PostStms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
PostStms lore -> DistNestT lore m ()
addPostStms PostStms lore
kernels
Stms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
Stms lore -> DistNestT lore m ()
postStm (Stms lore -> DistNestT lore m ())
-> DistNestT lore m (Stms lore) -> DistNestT lore m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< KernelNest
-> [Int]
-> Certificates
-> SubExp
-> [HistOp SOACS]
-> Lambda lore
-> [VName]
-> DistNestT lore m (Stms lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
KernelNest
-> [Int]
-> Certificates
-> SubExp
-> [HistOp SOACS]
-> Lambda lore
-> [VName]
-> DistNestT lore m (Stms lore)
segmentedHistKernel KernelNest
nest' [Int]
perm Certificates
cs SubExp
w [HistOp SOACS]
ops Lambda lore
lam' [VName]
as
DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc lore
acc'
Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Screma w form arrs))) DistAcc lore
acc
| Just ([Scan SOACS]
scans, Lambda
map_lam) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda)
forall lore. ScremaForm lore -> Maybe ([Scan lore], Lambda lore)
isScanomapSOAC ScremaForm SOACS
form,
Scan Lambda
lam Result
nes <- [Scan SOACS] -> Scan SOACS
forall lore. Bindable lore => [Scan lore] -> Scan lore
singleScan [Scan SOACS]
scans =
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
| Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
Lambda lore
map_lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
map_lam
Lambda lore
lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
lam
Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$
KernelNest
-> [Int]
-> SubExp
-> Lambda lore
-> Lambda lore
-> Result
-> [VName]
-> DistNestT lore m (Maybe (Stms lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
KernelNest
-> [Int]
-> SubExp
-> Lambda lore
-> Lambda lore
-> Result
-> [VName]
-> DistNestT lore m (Maybe (Stms lore))
segmentedScanomapKernel KernelNest
nest' [Int]
perm SubExp
w Lambda lore
lam' Lambda lore
map_lam' Result
nes [VName]
arrs DistNestT lore m (Maybe (Stms lore))
-> (Maybe (Stms lore) -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>=
Certificates
-> Stm SOACS
-> DistAcc lore
-> PostStms lore
-> DistAcc lore
-> Maybe (Stms lore)
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Certificates
-> Stm SOACS
-> DistAcc lore
-> PostStms lore
-> DistAcc lore
-> Maybe (Stms lore)
-> DistNestT lore m (DistAcc lore)
kernelOrNot Certificates
cs Stm SOACS
bnd DistAcc lore
acc PostStms lore
kernels DistAcc lore
acc'
Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc
maybeDistributeStm bnd :: Stm SOACS
bnd@(Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Screma w form arrs))) DistAcc lore
acc
| Just ([Reduce SOACS]
reds, Lambda
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda)
forall lore. ScremaForm lore -> Maybe ([Reduce lore], Lambda lore)
isRedomapSOAC ScremaForm SOACS
form,
Reduce Commutativity
comm Lambda
lam Result
nes <- [Reduce SOACS] -> Reduce SOACS
forall lore. Bindable lore => [Reduce lore] -> Reduce lore
singleReduce [Reduce SOACS]
reds,
Lambda -> Bool
forall lore. Lambda lore -> Bool
isIdentityLambda Lambda
map_lam Bool -> Bool -> Bool
|| Bool
incrementalFlattening =
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
| Just ([Int]
perm, [PatElemT Type]
pat_unused) <- PatternT Type -> Result -> Maybe ([Int], [PatElemT Type])
permutationAndMissing PatternT Type
Pattern
pat Result
res ->
Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
KernelNest
nest' <- [PatElemT Type] -> KernelNest -> DistNestT lore m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElemT Type]
pat_unused KernelNest
nest
Lambda lore
lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
lam
Lambda lore
map_lam' <- Lambda -> DistNestT lore m (Lambda lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Lambda -> DistNestT lore m (Lambda lore)
soacsLambda Lambda
map_lam
let comm' :: Commutativity
comm' | Lambda -> Bool
forall lore. Lambda lore -> Bool
commutativeLambda Lambda
lam = Commutativity
Commutative
| Bool
otherwise = Commutativity
comm
KernelNest
-> [Int]
-> SubExp
-> Commutativity
-> Lambda lore
-> Lambda lore
-> Result
-> [VName]
-> DistNestT lore m (Maybe (Stms lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
KernelNest
-> [Int]
-> SubExp
-> Commutativity
-> Lambda lore
-> Lambda lore
-> Result
-> [VName]
-> DistNestT lore m (Maybe (Stms lore))
regularSegmentedRedomapKernel KernelNest
nest' [Int]
perm SubExp
w Commutativity
comm' Lambda lore
lam' Lambda lore
map_lam' Result
nes [VName]
arrs DistNestT lore m (Maybe (Stms lore))
-> (Maybe (Stms lore) -> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>=
Certificates
-> Stm SOACS
-> DistAcc lore
-> PostStms lore
-> DistAcc lore
-> Maybe (Stms lore)
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Certificates
-> Stm SOACS
-> DistAcc lore
-> PostStms lore
-> DistAcc lore
-> Maybe (Stms lore)
-> DistNestT lore m (DistAcc lore)
kernelOrNot Certificates
cs Stm SOACS
bnd DistAcc lore
acc PostStms lore
kernels DistAcc lore
acc'
Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ ->
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc
-- Screma case: taken when 'incrementalFlattening' is set, or when
-- 'isRedomapSOAC' does not recognise the form.  The Screma is not
-- distributed directly; instead it is taken apart with
-- 'dissectScrema', and the statements that produces are distributed
-- in its place.
maybeDistributeStm (Let pat (StmAux cs _) (Op (Screma w form arrs))) acc
  | incrementalFlattening || isNothing (isRedomapSOAC form) = do
      soacs_scope <- asksScope scopeForSOACs
      (_, dissected) <- runBinderT (dissectScrema pat w form arrs) soacs_scope
      -- Each resulting statement is given the certificates of the
      -- original Screma statement.
      distributeMapBodyStms acc $ certify cs <$> dissected
-- Replicate case: a replicate with at least one outer dimension @d@,
-- bound by a single-element pattern, is rewritten as a parameterless
-- map of width @d@ whose body replicates over the remaining
-- dimensions @ds@.  The rewritten statement is then handed back to
-- 'maybeDistributeStm'.
maybeDistributeStm (Let pat aux (BasicOp (Replicate (Shape (d:ds)) v))) acc
  | [t] <- patternTypes pat = do
      -- Name for the per-iteration result inside the map body.
      tmp <- newVName "tmp"
      let elem_t = rowType t
          inner_stm =
            Let (Pattern [] [PatElem tmp elem_t]) aux $
            BasicOp $ Replicate (Shape ds) v
          map_lam =
            Lambda { lambdaParams = []
                   , lambdaBody = mkBody (oneStm inner_stm) [Var tmp]
                   , lambdaReturnType = [elem_t]
                   }
          map_stm = Let pat aux $ Op $ Screma d (mapSOAC map_lam) []
      maybeDistributeStm map_stm acc
-- Copy case: handled through 'distributeSingleUnaryStm'.  The
-- statement that is produced binds the outer pattern to a copy of the
-- array passed to the callback; the kernel nest itself is not
-- inspected.
maybeDistributeStm bnd@(Let _ aux (BasicOp Copy{})) acc =
  distributeSingleUnaryStm acc bnd copyWhole
  where
    copyWhole _ outerpat arr =
      return $ oneStm $ Let outerpat aux $ BasicOp (Copy arr)
-- Opaque case: only applies when the single pattern element is not of
-- primitive type.  Note that the statement produced is a 'Copy' of
-- the array passed to the callback, not an 'Opaque'.
maybeDistributeStm bnd@(Let (Pattern [] [pe]) aux (BasicOp Opaque{})) acc
  | not (primType (typeOf pe)) =
      distributeSingleUnaryStm acc bnd $ \_ outerpat arr ->
        return . oneStm . Let outerpat aux . BasicOp $ Copy arr
-- Rearrange case: the permutation is extended so that the dimensions
-- contributed by the kernel nest (one per nesting level) stay in
-- place, and the original permutation is shifted past them.  The
-- array is copied first and the rearrangement is applied to the copy.
-- NOTE(review): presumably the copy is there so the result does not
-- alias the input - confirm.
maybeDistributeStm bnd@(Let _ aux (BasicOp (Rearrange perm _))) acc =
  distributeSingleUnaryStm acc bnd $ \nest outerpat arr -> do
    arr_copy <- newVName (baseString arr)
    arr_t <- lookupType arr
    let nest_rank = 1 + length (snd nest)
        outer_dims = take nest_rank [0..]
        full_perm = outer_dims ++ [ i + nest_rank | i <- perm ]
        copy_stm =
          Let (Pattern [] [PatElem arr_copy arr_t]) aux $
          BasicOp $ Copy arr
        rearrange_stm =
          Let outerpat aux $ BasicOp $ Rearrange full_perm arr_copy
    return $ stmsFromList [copy_stm, rearrange_stm]
-- Reshape case: the new shape is the widths of the kernel nest
-- followed by the new dimensions of the original reshape, all
-- expressed as 'DimNew'.
maybeDistributeStm bnd@(Let _ aux (BasicOp (Reshape reshape _))) acc =
  distributeSingleUnaryStm acc bnd $ \nest outerpat arr ->
    let full_reshape = map DimNew $ kernelNestWidths nest ++ newDims reshape
    in return $ oneStm $ Let outerpat aux $ BasicOp $ Reshape full_reshape arr
-- Rotate case: a zero rotation is prepended for every width of the
-- kernel nest, so that only the original dimensions are rotated.
maybeDistributeStm stm@(Let _ aux (BasicOp (Rotate rots _))) acc =
  distributeSingleUnaryStm acc stm $ \nest outerpat arr ->
    let no_rotation = intConst Int32 0
        full_rots = (no_rotation <$ kernelNestWidths nest) ++ rots
    in return $ oneStm $ Let outerpat aux $ BasicOp $ Rotate full_rots arr
-- Update case (value is a variable, and 'sliceDims' of the slice is
-- non-empty): try to distribute the statement on its own.  If that
-- succeeds, the result of the nest is exactly the names bound by the
-- statement, and 'permutationAndMissing' finds a permutation, then
-- the update is posted as the statements built by
-- 'segmentedUpdateKernel'.  In every other situation the statement is
-- just added to the accumulator.
maybeDistributeStm stm@(Let pat aux (BasicOp (Update arr slice (Var v)))) acc
  | not $ null $ sliceDims slice = do
      distributed <- distributeSingleStm acc stm
      case distributed of
        Just (kernels, res, nest, acc')
          | res == (Var <$> patternNames (stmPattern stm)),
            Just (perm, pat_unused) <- permutationAndMissing pat res -> do
              addPostStms kernels
              localScope (typeEnvFromDistAcc acc') $ do
                nest' <- expandKernelNest pat_unused nest
                update_stms <-
                  segmentedUpdateKernel nest' perm (stmAuxCerts aux) arr slice v
                postStm update_stms
                return acc'
        _ ->
          addStmToAcc stm acc
-- Single-index update case: an update at one fixed index of a
-- rank-1 array is rewritten as a 'Scatter' of width one, whose lambda
-- takes no parameters and returns the index and the value.  The
-- rewrite only happens when no statement already in the accumulator
-- is a 'DoLoop' or an 'Op' (see 'amortises'); the rewritten statement
-- is handed back to 'maybeDistributeStm'.
maybeDistributeStm (Let pat aux (BasicOp (Update arr [DimFix i] v))) acc
  | [t] <- patternTypes pat,
    arrayRank t == 1,
    all (not . amortises . stmExp) (distStms acc) =
      let write_lam =
            Lambda { lambdaParams = []
                   , lambdaReturnType = [Prim int32, stripArray 1 t]
                   , lambdaBody = mkBody mempty [i, v]
                   }
          -- Destination triple: outer size of the array, the
          -- literal 1, and the array being written.
          dest = (arraySize 0 t, 1, arr)
          scatter = Scatter (intConst Int32 1) write_lam [] [dest]
      in maybeDistributeStm (Let pat aux $ Op scatter) acc
  where
    -- True for loops and ops; False for every other expression.
    amortises e = case e of
      DoLoop{} -> True
      Op{} -> True
      _ -> False
-- Concat case: try to distribute the statement on its own.  If that
-- succeeds, build a segmented concatenation with 'isSegmentedOp' and
-- let 'kernelOrNot' decide what to do with the outcome; otherwise the
-- statement is just added to the accumulator.
maybeDistributeStm stm@(Let _ aux (BasicOp (Concat d x xs w))) acc = do
  distributed <- distributeSingleStm acc stm
  case distributed of
    Just (kernels, _, nest, acc') ->
      localScope (typeEnvFromDistAcc acc') $ do
        maybe_stms <- segmentedConcat nest
        kernelOrNot (stmAuxCerts aux) stm acc kernels acc' maybe_stms
    Nothing ->
      addStmToAcc stm acc
  where
    -- The concatenation dimension is shifted by the depth of the
    -- kernel nest (its inner nestings plus the outermost one).
    -- NOTE(review): the callback pattern assumes a non-empty array
    -- list; the list passed in is x:xs, so presumably that always
    -- holds - confirm against 'isSegmentedOp'.
    segmentedConcat nest =
      isSegmentedOp nest [0] w mempty mempty [] (x:xs) $
      \pat _ _ _ (x':xs') _ ->
        addStm $ Let pat aux $ BasicOp $
        Concat (d + length (snd nest) + 1) x' xs' w
-- Fallback: any statement not matched by an earlier equation is
-- simply added to the accumulator.
maybeDistributeStm stm dist_acc =
  stm `addStmToAcc` dist_acc
distributeSingleUnaryStm :: (MonadFreshNames m, DistLore lore) =>
DistAcc lore -> Stm SOACS
-> (KernelNest -> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
distributeSingleUnaryStm :: DistAcc lore
-> Stm SOACS
-> (KernelNest
-> PatternT Type -> VName -> DistNestT lore m (Stms lore))
-> DistNestT lore m (DistAcc lore)
distributeSingleUnaryStm DistAcc lore
acc Stm SOACS
bnd KernelNest
-> PatternT Type -> VName -> DistNestT lore m (Stms lore)
f =
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
DistAcc lore
-> Stm SOACS
-> DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
distributeSingleStm DistAcc lore
acc Stm SOACS
bnd DistNestT
lore m (Maybe (PostStms lore, Result, KernelNest, DistAcc lore))
-> (Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
Just (PostStms lore
kernels, Result
res, KernelNest
nest, DistAcc lore
acc')
| Result
res Result -> Result -> Bool
forall a. Eq a => a -> a -> Bool
== (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var (PatternT Type -> [VName]
forall attr. PatternT attr -> [VName]
patternNames (PatternT Type -> [VName]) -> PatternT Type -> [VName]
forall a b. (a -> b) -> a -> b
$ Stm SOACS -> Pattern
forall lore. Stm lore -> Pattern lore
stmPattern Stm SOACS
bnd),
(LoopNesting
outer, [LoopNesting]
inners) <- KernelNest
nest,
[(Param Type
arr_p, VName
arr)] <- LoopNesting -> [(Param Type, VName)]
loopNestingParamsAndArrs LoopNesting
outer,
KernelNest -> Names
boundInKernelNest KernelNest
nest Names -> Names -> Names
`namesIntersection` Stm SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn Stm SOACS
bnd
Names -> Names -> Bool
forall a. Eq a => a -> a -> Bool
== VName -> Names
oneName (Param Type -> VName
forall attr. Param attr -> VName
paramName Param Type
arr_p) -> do
PostStms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
PostStms lore -> DistNestT lore m ()
addPostStms PostStms lore
kernels
let outerpat :: PatternT Type
outerpat = LoopNesting -> PatternT Type
loopNestingPattern (LoopNesting -> PatternT Type) -> LoopNesting -> PatternT Type
forall a b. (a -> b) -> a -> b
$ KernelNest -> LoopNesting
forall a b. (a, b) -> a
fst KernelNest
nest
Scope lore
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc lore -> Scope lore
forall lore. DistLore lore => DistAcc lore -> Scope lore
typeEnvFromDistAcc DistAcc lore
acc') (DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore))
-> DistNestT lore m (DistAcc lore)
-> DistNestT lore m (DistAcc lore)
forall a b. (a -> b) -> a -> b
$ do
(VName
arr', Stms lore
pre_stms) <- VName -> [LoopNesting] -> DistNestT lore m (VName, Stms lore)
forall (m :: * -> *) lore lore.
(HasScope lore m, MonadFreshNames m, LetAttr lore ~ Type,
ExpAttr lore ~ ()) =>
VName -> [LoopNesting] -> m (VName, Stms lore)
repeatMissing VName
arr (LoopNesting
outerLoopNesting -> [LoopNesting] -> [LoopNesting]
forall a. a -> [a] -> [a]
:[LoopNesting]
inners)
Stms lore
f_stms <- Stms lore
-> DistNestT lore m (Stms lore) -> DistNestT lore m (Stms lore)
forall lore a (m :: * -> *) b.
(Scoped lore a, LocalScope lore m) =>
a -> m b -> m b
inScopeOf Stms lore
pre_stms (DistNestT lore m (Stms lore) -> DistNestT lore m (Stms lore))
-> DistNestT lore m (Stms lore) -> DistNestT lore m (Stms lore)
forall a b. (a -> b) -> a -> b
$ KernelNest
-> PatternT Type -> VName -> DistNestT lore m (Stms lore)
f KernelNest
nest PatternT Type
outerpat VName
arr'
Stms lore -> DistNestT lore m ()
forall (m :: * -> *) lore.
Monad m =>
Stms lore -> DistNestT lore m ()
postStm (Stms lore -> DistNestT lore m ())
-> Stms lore -> DistNestT lore m ()
forall a b. (a -> b) -> a -> b
$ Stms lore
pre_stms Stms lore -> Stms lore -> Stms lore
forall a. Semigroup a => a -> a -> a
<> Stms lore
f_stms
DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc lore
acc'
Maybe (PostStms lore, Result, KernelNest, DistAcc lore)
_ -> Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore) =>
Stm SOACS -> DistAcc lore -> DistNestT lore m (DistAcc lore)
addStmToAcc Stm SOACS
bnd DistAcc lore
acc
where
repeatMissing :: VName -> [LoopNesting] -> m (VName, Stms lore)
repeatMissing VName
arr [LoopNesting]
inners = do
Type
arr_t <- VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
let shapes :: [Shape]
shapes = VName -> Type -> [LoopNesting] -> [Shape]
forall shape u.
ArrayShape shape =>
VName -> TypeBase shape u -> [LoopNesting] -> [Shape]
determineRepeats VName
arr Type
arr_t [LoopNesting]
inners
if (Shape -> Bool) -> [Shape] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Shape -> Shape -> Bool
forall a. Eq a => a -> a -> Bool
==Result -> Shape
forall d. [d] -> ShapeBase d
Shape []) [Shape]
shapes then (VName, Stms lore) -> m (VName, Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (VName
arr, Stms lore
forall a. Monoid a => a
mempty)
else do
let ([Shape]
outer_shapes, Shape
inner_shape) = [Shape] -> Type -> ([Shape], Shape)
repeatShapes [Shape]
shapes Type
arr_t
arr_t' :: Type
arr_t' = [Shape] -> Shape -> Type -> Type
repeatDims [Shape]
outer_shapes Shape
inner_shape Type
arr_t
VName
arr' <- String -> m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> m VName) -> String -> m VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString VName
arr
(VName, Stms lore) -> m (VName, Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (VName
arr', Stm lore -> Stms lore
forall lore. Stm lore -> Stms lore
oneStm (Stm lore -> Stms lore) -> Stm lore -> Stms lore
forall a b. (a -> b) -> a -> b
$ Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let ([PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [] [VName -> Type -> PatElemT Type
forall attr. VName -> attr -> PatElemT attr
PatElem VName
arr' Type
arr_t']) (() -> StmAux ()
forall attr. attr -> StmAux attr
defAux ()) (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ [Shape] -> Shape -> VName -> BasicOp
Repeat [Shape]
outer_shapes Shape
inner_shape VName
arr)
determineRepeats :: VName -> TypeBase shape u -> [LoopNesting] -> [Shape]
determineRepeats VName
arr TypeBase shape u
arr_t [LoopNesting]
nests
| ([LoopNesting]
skipped, LoopNesting
arr_nest:[LoopNesting]
nests') <- (LoopNesting -> Bool)
-> [LoopNesting] -> ([LoopNesting], [LoopNesting])
forall a. (a -> Bool) -> [a] -> ([a], [a])
break (VName -> LoopNesting -> Bool
hasInput VName
arr) [LoopNesting]
nests,
[(Param Type
arr_p, VName
_)] <- LoopNesting -> [(Param Type, VName)]
loopNestingParamsAndArrs LoopNesting
arr_nest =
Result -> Shape
forall d. [d] -> ShapeBase d
Shape ((LoopNesting -> SubExp) -> [LoopNesting] -> Result
forall a b. (a -> b) -> [a] -> [b]
map LoopNesting -> SubExp
loopNestingWidth [LoopNesting]
skipped) Shape -> [Shape] -> [Shape]
forall a. a -> [a] -> [a]
:
VName -> TypeBase shape u -> [LoopNesting] -> [Shape]
determineRepeats (Param Type -> VName
forall attr. Param attr -> VName
paramName Param Type
arr_p) (TypeBase shape u -> TypeBase shape u
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType TypeBase shape u
arr_t) [LoopNesting]
nests'
| Bool
otherwise =
Result -> Shape
forall d. [d] -> ShapeBase d
Shape ((LoopNesting -> SubExp) -> [LoopNesting] -> Result
forall a b. (a -> b) -> [a] -> [b]
map LoopNesting -> SubExp
loopNestingWidth [LoopNesting]
nests) Shape -> [Shape] -> [Shape]
forall a. a -> [a] -> [a]
: Int -> Shape -> [Shape]
forall a. Int -> a -> [a]
replicate (TypeBase shape u -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank TypeBase shape u
arr_t) (Result -> Shape
forall d. [d] -> ShapeBase d
Shape [])
hasInput :: VName -> LoopNesting -> Bool
hasInput VName
arr LoopNesting
nest
| [(Param Type
_, VName
arr')] <- LoopNesting -> [(Param Type, VName)]
loopNestingParamsAndArrs LoopNesting
nest, VName
arr' VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
arr = Bool
True
| Bool
otherwise = Bool
False
-- | Run 'distributeIfPossible' on the accumulator.  When it produces
-- a new accumulator that one is returned; when it yields 'Nothing'
-- the accumulator passed in is returned unchanged.
distribute :: (MonadFreshNames m, DistLore lore) => DistAcc lore -> DistNestT lore m (DistAcc lore)
distribute acc = do
  distributed <- distributeIfPossible acc
  return $ fromMaybe acc distributed
-- | Fetch the 'distSegLevel' function from the environment and lift
-- it so that it can be used from a 'BinderT' stacked on top of
-- 'DistNestT'.
--
-- The lifted function runs the underlying level constructor with
-- 'runBinderT' in the scope of the calling binder, re-adds whatever
-- statements that run produced to the calling binder, and returns the
-- resulting level.
mkSegLevel :: (MonadFreshNames m, DistLore lore) => DistNestT lore m (MkSegLevel lore (DistNestT lore m))
mkSegLevel = do
  base_mk_lvl <- asks distSegLevel
  let lifted_mk_lvl width description recommendation = do
        cur_scope <- askScope
        -- Two lifts: base monad -> DistNestT -> BinderT.
        (level, level_stms) <-
          lift . lift $
            runBinderT (base_mk_lvl width description recommendation) cur_scope
        addStms level_stms
        return level
  return lifted_mk_lvl
-- | Call 'tryDistribute' with the current nestings (from the
-- environment) and the targets and statements held by the
-- accumulator.
--
-- On success the statements it produced are passed to 'postStm' and a
-- new accumulator is returned that carries the updated targets and no
-- statements.  On failure the result is 'Nothing' and nothing is
-- posted.
distributeIfPossible :: (MonadFreshNames m, DistLore lore) => DistAcc lore -> DistNestT lore m (Maybe (DistAcc lore))
distributeIfPossible acc = do
  nestings <- asks distNest
  mk_lvl <- mkSegLevel
  outcome <- tryDistribute mk_lvl nestings (distTargets acc) (distStms acc)
  -- 'forM' over a 'Maybe': the body only runs in the 'Just' case.
  forM outcome $ \(new_targets, kernel_stms) -> do
    postStm kernel_stms
    return DistAcc { distTargets = new_targets
                   , distStms = mempty
                   }
-- | First call 'tryDistribute' on the statements already in the
-- accumulator, then call 'tryDistributeStm' on the single statement
-- @bnd@ against the resulting targets.
--
-- If either step yields 'Nothing', so does this function.  Otherwise
-- the result is a tuple of: the statements from the first step wrapped
-- in 'PostStms' (they are returned, not posted), the result and kernel
-- nest from the second step, and a new accumulator with the final
-- targets and no statements.
distributeSingleStm :: (MonadFreshNames m, DistLore lore) =>
                       DistAcc lore -> Stm SOACS
                    -> DistNestT lore m (Maybe (PostStms lore,
                                                Result,
                                                KernelNest,
                                                DistAcc lore))
distributeSingleStm acc bnd = do
  nestings <- asks distNest
  mk_lvl <- mkSegLevel
  -- MaybeT gives the short-circuit-on-Nothing behaviour for both steps.
  runMaybeT $ do
    (targets, distributed_bnds) <-
      MaybeT $ tryDistribute mk_lvl nestings (distTargets acc) (distStms acc)
    (res, targets', new_kernel_nest) <-
      MaybeT $ tryDistributeStm nestings targets bnd
    return ( PostStms distributed_bnds
           , res
           , new_kernel_nest
           , DistAcc { distTargets = targets'
                     , distStms = mempty
                     }
           )
-- | Construct the statements for a scatter located inside the map
-- nest @nest@.
--
-- The scatter itself is pushed onto the nest as an extra innermost
-- 'MapNesting' before calling 'flatKernel', so that the flattened
-- index space and kernel inputs also cover the scatter dimension.
-- Each destination array in @dests@ must be the name of one of those
-- kernel inputs; otherwise this function calls 'error'.
--
-- The kernel body consists of the statements of @lam@, plus (when
-- @cs@ is non-empty) one certified copy per index result.  The kernel
-- results are 'WriteReturns', one per destination.  The final
-- statements bind the kernel to the value elements of the outermost
-- nesting's pattern, rearranged by @perm@, and are passed through
-- 'renameStm' before being returned.
segmentedScatterKernel :: (MonadFreshNames m, DistLore lore) =>
                          KernelNest
                       -> [Int]
                       -> PatternT Type
                       -> Certificates
                       -> SubExp
                       -> Lambda lore
                       -> [VName] -> [(SubExp,Int,VName)]
                       -> DistNestT lore m (Stms lore)
segmentedScatterKernel nest perm scatter_pat cs scatter_w lam ivs dests = do
  (ispace, kernel_inps) <- flatKernel scatter_nest

  -- Locate the kernel input corresponding to each destination array.
  dest_inps <- mapM (destInput kernel_inps) dest_arrs

  mk_lvl <- mkSegLevel

  -- Rebuild the lambda body, attaching the certificates (if any) to
  -- the index results.
  (certified_is, k_body_stms) <- runBinder $ do
    addStms $ bodyStms lam_body
    let certify i
          | cs == mempty = return i
          | otherwise =
              certifying cs $ letSubExp "scatter_i" $ BasicOp $ SubExp i
    mapM certify index_res

  let k_results =
        zipWith3 (writeReturn ispace) dest_ws dest_inps $
        chunks dest_ns $ zip certified_is value_res
      k_body = KernelBody () k_body_stms k_results

  (k, k_bnds) <- mapKernel mk_lvl ispace kernel_inps rts k_body

  kernel_stms <- runBinder_ $ do
    addStms k_bnds
    letBind_ result_pat $ Op $ segOp k
  traverse renameStm kernel_stms
  where
    lam_body = lambdaBody lam

    -- The original nest with the scatter added as an innermost map.
    scatter_nest =
      pushInnerKernelNesting
        (scatter_pat, bodyResult lam_body)
        (MapNesting scatter_pat cs scatter_w $ zip (lambdaParams lam) ivs)
        nest

    (dest_ws, dest_ns, dest_arrs) = unzip3 dests

    -- The lambda results are split at 'sum dest_ns': the first part
    -- is used as indices, the remainder as values.
    num_indices = sum dest_ns
    (index_res, value_res) = splitAt num_indices $ bodyResult lam_body

    -- One return type per destination: the first type of each chunk
    -- of the value part of the lambda's return type.
    rts = concatMap (take 1) $ chunks dest_ns $
          drop num_indices $ lambdaReturnType lam

    result_pat =
      Pattern [] $ rearrangeShape perm $
      patternValueElements $ loopNestingPattern $ fst nest

    destInput inps a =
      case find ((== a) . kernelInputName) inps of
        Just inp -> return inp
        Nothing -> error "Ill-typed nested scatter encountered."

    -- The last index-space entry is replaced: its width by the
    -- destination width @aw@, and its thread index by the computed
    -- write index.
    writeReturn space aw inp is_vs =
      WriteReturns (init ws ++ [aw]) (kernelInputArray inp) $
      map (\(i, v) -> (map Var (init gtids) ++ [i], v)) is_vs
      where (gtids, ws) = unzip space
-- | Construct the statements for a kernel that writes the elements of
-- @v@ into the kernel input corresponding to @arr@, at the positions
-- selected by @slice@, for every point of the iteration space of
-- @nest@.
--
-- The iteration space is the one obtained from 'flatKernel' on @nest@,
-- extended with one fresh thread index (named @"gtid_slice"@) per
-- dimension of @slice@.  Each thread reads @v@ at its slice indexes
-- (under the certificates @cs@) and produces a 'WriteReturns' into
-- the underlying array of the kernel input named @arr@.
--
-- The resulting kernel is bound to the value elements of the nest's
-- pattern, rearranged by @perm@, and all produced statements are
-- renamed with 'renameStm'.
--
-- Calls 'error' if no kernel input is named @arr@.
segmentedUpdateKernel :: (MonadFreshNames m, DistLore lore) =>
                         KernelNest
                      -> [Int]
                      -> Certificates
                      -> VName
                      -> Slice SubExp
                      -> VName
                      -> DistNestT lore m (Stms lore)
segmentedUpdateKernel nest perm cs arr slice v = do
  (base_ispace, kernel_inps) <- flatKernel nest
  let slice_dims = sliceDims slice

  -- One extra thread index for every dimension of the slice.
  slice_gtids <- forM slice_dims $ \_ -> newVName "gtid_slice"
  let ispace = base_ispace ++ zip slice_gtids slice_dims

  ((res_t, res), kstms) <- runBinder $ do
    -- The element of 'v' that this thread is responsible for.
    v' <- certifying cs $
          letSubExp "v" $ BasicOp $ Index v
          [ DimFix (Var gtid) | gtid <- slice_gtids ]

    -- Translate the thread's slice-relative indexes into indexes
    -- into the array being updated.
    let toPrimExp = primExpFromSubExp int32
        gtid_exps = map (toPrimExp . Var) slice_gtids
        slice_exps = map (fmap toPrimExp) slice
    slice_is <- forM (fixSlice slice_exps gtid_exps) $ \e ->
      letSubExp "index" =<< toExp e

    let write_is = [ Var gtid | (gtid, _) <- base_ispace ] ++ slice_is
        arr' = case find ((==arr) . kernelInputName) kernel_inps of
                 Just inp -> kernelInputArray inp
                 Nothing -> error "incorrectly typed Update"

    arr_t <- lookupType arr'
    v_t <- subExpType v'
    return (v_t, WriteReturns (arrayDims arr_t) arr' [(write_is, v')])

  mk_lvl <- mkSegLevel
  (k, prestms) <- mapKernel mk_lvl ispace kernel_inps [res_t] $
                  KernelBody () kstms [res]

  -- The kernel result is bound to the (permuted) pattern of the nest.
  let pat = Pattern [] $ rearrangeShape perm $
            patternValueElements $ loopNestingPattern $ fst nest
  stms <- runBinder_ $ do
    addStms prestms
    letBind_ pat $ Op $ segOp k
  traverse renameStm stms
-- | Construct the statements for a histogram nested inside the map
-- nest @nest@.
--
-- The iteration space and kernel inputs come from 'flatKernel'.  The
-- destination arrays of every operator in @ops@ are replaced by the
-- underlying arrays ('kernelInputArray') of the kernel inputs with the
-- same name; a destination that is not a kernel input is an 'error'.
-- The segment level is created by the 'distSegLevel' function of the
-- environment, applied to @hist_w@ followed by the sizes of the
-- iteration space, and the actual construction is delegated to
-- 'histKernel'.  The result is bound to the value elements of the
-- nest's pattern, rearranged by @perm@.
segmentedHistKernel :: (MonadFreshNames m, DistLore lore) =>
                       KernelNest
                    -> [Int]
                    -> Certificates
                    -> SubExp
                    -> [SOACS.HistOp SOACS]
                    -> Lambda lore
                    -> [VName]
                    -> DistNestT lore m (Stms lore)
segmentedHistKernel nest perm cs hist_w ops lam arrs = do
  (ispace, inputs) <- flatKernel nest
  let nest_pat_elems = patternValueElements $ loopNestingPattern $ fst nest
      orig_pat = Pattern [] $ rearrangeShape perm nest_pat_elems

  -- Point every destination at the array behind the matching kernel
  -- input.
  ops' <- forM ops $ \(SOACS.HistOp num_bins rf dests nes op) -> do
    dests' <- mapM (fmap kernelInputArray . findInput inputs) dests
    return $ SOACS.HistOp num_bins rf dests' nes op

  mk_lvl <- asks distSegLevel
  scope <- askScope
  onLambda <- asks distOnSOACSLambda

  -- Only the converted lambda is kept; the statements returned by
  -- 'runBinder' are dropped.
  let onLambda' = fmap fst . runBinder . onLambda
      level_dims = hist_w : map snd ispace
      build = do
        lvl <- mk_lvl level_dims "seghist" $ NoRecommendation SegNoVirt
        stms <- histKernel onLambda' lvl orig_pat ispace inputs cs hist_w ops' lam arrs
        addStms stms

  lift $ runBinderT_ build scope
  where
    -- Look up the kernel input with the given name.
    findInput kernel_inps a =
      case find ((==a) . kernelInputName) kernel_inps of
        Just inp -> return inp
        Nothing -> error "Ill-typed nested Hist encountered."
-- | Construct the statements of a histogram kernel.
--
-- Every SOACS-level operator in @ops@ is passed through
-- 'determineReduceOp', its lambda is converted with @onLambda@, and
-- the pieces are assembled into a 'HistOp'.  Kernel inputs whose
-- underlying array is a destination of one of the resulting operators
-- are dropped before calling 'segHist'.  The statements produced by
-- 'segHist' are renamed with 'renameStm' and added under the
-- certificates @cs@.
histKernel :: (MonadBinder m, DistLore (Lore m)) =>
              (Lambda SOACS -> m (Lambda (Lore m)))
           -> SegOpLevel (Lore m)
           -> PatternT Type -> [(VName, SubExp)] -> [KernelInput]
           -> Certificates -> SubExp -> [SOACS.HistOp SOACS]
           -> Lambda (Lore m) -> [VName]
           -> m (Stms (Lore m))
histKernel onLambda lvl orig_pat ispace inputs cs hist_w ops lam arrs = runBinderT'_ $ do
  ops' <- forM ops $ \(SOACS.HistOp num_bins rf dests nes op) -> do
    (soacs_op, nes', shape) <- determineReduceOp op nes
    kernel_op <- lift $ onLambda soacs_op
    return $ HistOp num_bins rf dests nes' shape kernel_op

  -- Destination arrays must not also be passed as ordinary inputs.
  let all_dests = concatMap histDest ops'
      inputs' = filter ((`notElem` all_dests) . kernelInputArray) inputs

  certifying cs $ do
    hist_stms <- segHist lvl orig_pat hist_w ispace inputs' ops' lam arrs
    renamed <- traverse renameStm hist_stms
    addStms renamed
-- | Possibly simplify a histogram operator and its neutral elements.
--
-- If every neutral element in @nes@ is a variable, the operator is
-- analysed with 'isVectorMap', giving a shape and a (possibly inner)
-- lambda.  Each neutral element is then replaced by the result of
-- indexing it with one fixed index 0 (an 'Int32' constant) per
-- dimension of that shape, completed to a full slice with
-- 'fullSlice'; the indexing statements are named @"hist_ne"@.
--
-- If some neutral element is not a variable, the operator and neutral
-- elements are returned unchanged together with an empty shape.
--
-- NOTE(review): the indexing statements are also emitted when the
-- shape found by 'isVectorMap' is empty -- confirm this is intended.
determineReduceOp :: (MonadBinder m, Lore m ~ lore) =>
                     Lambda SOACS -> [SubExp]
                  -> m (Lambda SOACS, [SubExp], Shape)
determineReduceOp lam nes
  | Just ne_vs <- mapM subExpVar nes = do
      let (shape, lam') = isVectorMap lam
          -- One fixed zero index for each vectorised dimension.
          zero_is = replicate (shapeRank shape) $ DimFix $ intConst Int32 0
      nes' <- forM ne_vs $ \ne_v -> do
        ne_v_t <- lookupType ne_v
        letSubExp "hist_ne" $ BasicOp $ Index ne_v $ fullSlice ne_v_t zero_is
      return (lam', nes', shape)
  | otherwise =
      return (lam, nes, mempty)
-- | Peel off outer maps from a lambda.
--
-- A lambda matches when its body is a single 'Screma' statement with
-- no context pattern elements, the body result is exactly the names
-- bound by that statement, the 'Screma' form is a map according to
-- 'isMapSOAC', and the arrays being mapped over are exactly the
-- lambda's parameters.  In that case the width of the map is
-- prepended to the shape obtained by recursing on the map's lambda,
-- and the innermost lambda is returned.  Otherwise the result is the
-- empty shape and the lambda unchanged.
isVectorMap :: Lambda SOACS -> (Shape, Lambda SOACS)
isVectorMap lam =
  case stmsToList $ bodyStms body of
    [Let (Pattern [] pes) _ (Op (Screma w form arrs))]
      | bodyResult body == map (Var . patElemName) pes,
        Just map_lam <- isMapSOAC form,
        arrs == map paramName (lambdaParams lam) ->
          let (inner_shape, inner_lam) = isVectorMap map_lam
          in (Shape [w] <> inner_shape, inner_lam)
    _ ->
      (mempty, lam)
  where body = lambdaBody lam
segmentedScanomapKernel :: (MonadFreshNames m, DistLore lore) =>
KernelNest
-> [Int]
-> SubExp
-> Lambda lore -> Lambda lore
-> [SubExp] -> [VName]
-> DistNestT lore m (Maybe (Stms lore))
-- Attempt to express a scanomap that sits inside a kernel nest as one
-- segmented scan.  All applicability checks and input preparation are
-- delegated to 'isSegmentedOp'; the continuation passed to it only has
-- to construct the scan itself.
segmentedScanomapKernel nest perm segment_size lam map_lam nes arrs = do
  -- The environment decides how a segmented-operation level is built.
  mkLevel <- asks distSegLevel
  isSegmentedOp nest perm segment_size (freeIn lam) (freeIn map_lam) nes arrs $
    \pat ispace inps nes' _ _ -> do
      let dims = segment_size : map snd ispace
          -- The scan operator is always treated as noncommutative here.
          scan_op = SegBinOp Noncommutative lam nes' mempty
      lvl <- mkLevel dims "segscan" (NoRecommendation SegNoVirt)
      -- Note that the original (unflattened) 'arrs' are handed to
      -- 'segScan'; the prepared arrays offered by 'isSegmentedOp' are
      -- deliberately ignored by this continuation.
      stms <- segScan lvl pat segment_size [scan_op] map_lam arrs ispace inps
      -- Rename every produced statement before emitting it.
      renamed <- traverse renameStm stms
      addStms renamed
-- | Attempt to express a redomap that sits inside a kernel nest as one
-- segmented reduction.  Applicability checks and input preparation are
-- delegated to 'isSegmentedOp'; the continuation given to it builds a
-- single-operator 'segRed' using the supplied commutativity.
regularSegmentedRedomapKernel :: (MonadFreshNames m, DistLore lore) =>
                                 KernelNest
                              -> [Int]
                              -> SubExp -> Commutativity
                              -> Lambda lore -> Lambda lore
                              -> [SubExp] -> [VName]
                              -> DistNestT lore m (Maybe (Stms lore))
regularSegmentedRedomapKernel nest perm segment_size comm lam map_lam nes arrs = do
  -- The environment decides how a segmented-operation level is built.
  mkLevel <- asks distSegLevel
  isSegmentedOp nest perm segment_size (freeIn lam) (freeIn map_lam) nes arrs $
    \pat ispace inps nes' _ _ -> do
      let dims = segment_size : map snd ispace
          red_op = SegBinOp comm lam nes' mempty
      lvl <- mkLevel dims "segred" (NoRecommendation SegNoVirt)
      -- The original (unflattened) 'arrs' are passed on; the prepared
      -- arrays offered by 'isSegmentedOp' are ignored here.
      stms <- segRed lvl pat segment_size [red_op] map_lam arrs ispace inps
      -- Rename every produced statement before emitting it.
      renamed <- traverse renameStm stms
      addStms renamed
-- | Shared driver for turning an operation nested in a kernel nest into
-- a segmented operation.
--
-- The result is 'Nothing' when any of the following holds:
--
-- * a name in @free_in_op@ is bound by the nest;
--
-- * a neutral element is a variable bound by the nest;
--
-- * an input array is neither a kernel input nor free in the nest (see
--   @prepareArr@ below for the precise cases).
--
-- Otherwise the continuation @m@ is run in a fresh binder and the
-- statements it produced are returned.  It receives: the (permuted)
-- pattern of the outermost nesting, the flat iteration space, the
-- kernel inputs, the checked neutral elements, the prepared nested
-- arrays, and their flattened counterparts.
--
-- The fifth argument (@_free_in_fold_op@) is accepted but not used.
isSegmentedOp :: (MonadFreshNames m, DistLore lore) =>
                 KernelNest
              -> [Int]
              -> SubExp
              -> Names -> Names
              -> [SubExp] -> [VName]
              -> (PatternT Type
                  -> [(VName, SubExp)]
                  -> [KernelInput]
                  -> [SubExp] -> [VName] -> [VName]
                  -> BinderT lore m ())
              -> DistNestT lore m (Maybe (Stms lore))
isSegmentedOp nest perm segment_size free_in_op _free_in_fold_op nes arrs m = runMaybeT $ do
  let nest_bound = boundInKernelNest nest

  (ispace, kernel_inps) <- flatKernel nest

  -- The operator itself may not depend on anything bound by the nest.
  when (free_in_op `namesIntersect` nest_bound) $
    fail "Non-fold lambda uses nest-bound parameters."

  let space_dims = map snd ispace
      space_indices = map (Var . fst) ispace

      -- A neutral element is acceptable unless it is a nest-bound
      -- variable.
      checkNeutral ne = case ne of
        Var v | v `nameIn` nest_bound -> fail "Neutral element bound in nest"
        _ -> pure ne

      -- Decide, per input array, how the array that is actually fed
      -- to the segmented operation will be constructed.  The
      -- construction itself is deferred: it is returned as a binder
      -- action and only run further down.
      prepareArr arr =
        case find (\inp -> kernelInputName inp == arr) kernel_inps of
          Just inp
            -- Indexed by exactly the iteration space: use the
            -- underlying array as it is.
            | kernelInputIndices inp == space_indices ->
                pure $ pure $ kernelInputArray inp
            -- Indexed differently, but the underlying array is not
            -- bound by the nest: repeat it along the missing
            -- dimensions.
            | not (kernelInputArray inp `nameIn` nest_bound) ->
                pure $ replicateMissing ispace inp
          Nothing
            -- Not a kernel input, and not bound by the nest:
            -- replicate it across the whole iteration space.
            | not (arr `nameIn` nest_bound) ->
                pure $
                  letExp (baseString arr ++ "_repd") $
                    BasicOp $ Replicate (Shape space_dims) (Var arr)
          _ ->
            fail "Input not free or outermost."

  nes' <- mapM checkNeutral nes
  mk_arrs <- mapM prepareArr arrs
  scope <- lift askScope

  lift . lift . flip runBinderT_ scope $ do
    -- Product of the segment size and every iteration-space dimension.
    total_num_elements <-
      letSubExp "total_num_elements"
        =<< foldBinOp (Mul Int32 OverflowUndef) segment_size space_dims

    let -- Number of outer dimensions handed to 'reshapeOuter'.
        -- NOTE(review): presumably the nest depth plus the segment
        -- dimension -- confirm against 'reshapeOuter'.
        outer_rank = 2 + length (snd nest)

        flatten arr = do
          arr_t <- lookupType arr
          let reshape =
                reshapeOuter [DimNew total_num_elements] outer_rank (arrayShape arr_t)
          letExp (baseString arr ++ "_flat") $ BasicOp $ Reshape reshape arr

    -- Now actually build the arrays that were planned by prepareArr.
    nested_arrs <- sequence mk_arrs
    arrs' <- mapM flatten nested_arrs

    let pat =
          Pattern [] $
            rearrangeShape perm $
              patternValueElements $ loopNestingPattern $ fst nest

    m pat ispace kernel_inps nes' nested_arrs arrs'

  where
    -- Bind a 'Repeat' of the array underlying a kernel input, using
    -- the shapes computed by 'determineRepeats' and 'repeatShapes'.
    replicateMissing ispace inp = do
      let src = kernelInputArray inp
      t <- lookupType src
      let shapes = determineRepeats ispace (kernelInputIndices inp)
          (outer_shapes, inner_shape) = repeatShapes shapes t
      letExp "repeated" $ BasicOp $ Repeat outer_shapes inner_shape src

    -- Walk the input's indices left to right.  For each index, collect
    -- the sizes of the iteration-space dimensions that precede the
    -- dimension it names, then continue after that dimension.  The
    -- final element holds whatever dimensions are left over.
    determineRepeats ispace (i : is) =
      let (skipped, rest) = break ((== i) . Var . fst) ispace
       in Shape (map snd skipped) : determineRepeats (drop 1 rest) is
    determineRepeats ispace _ =
      [Shape $ map snd ispace]
-- | Relate the value elements of a pattern to a result list.
--
-- Pattern elements whose names do not occur free in @res@ are called
-- /unused/.  The result list is extended with those unused names, and
-- the function succeeds only when the pattern's value names are a
-- permutation (as decided by 'isPermutationOf') of that extended list.
-- On success it yields the permutation together with the unused
-- elements.
permutationAndMissing :: PatternT Type -> [SubExp] -> Maybe ([Int], [PatElemT Type])
permutationAndMissing pat res = do
  perm <- map elemAsSubExp pes `isPermutationOf` (res ++ map elemAsSubExp unused)
  pure (perm, unused)
  where
    pes = patternValueElements pat
    res_names = freeIn res
    -- Keep exactly the elements that the result does not mention.
    unused = filter (not . (`nameIn` res_names) . patElemName) pes
    elemAsSubExp = Var . patElemName
-- | Append extra pattern elements to the pattern of every nesting in a
-- kernel nest.
--
-- Each nesting gets its own freshly named copies of @pes@.  A copy's
-- type is the original element type wrapped in an array whose shape is
-- the nesting's own width followed by the widths of all nestings inside
-- it; for the outermost nesting that is its width followed by every
-- inner width.
expandKernelNest :: MonadFreshNames m =>
                    [PatElemT Type] -> KernelNest -> m KernelNest
expandKernelNest pes (outer_nest, inner_nests) = do
  outer_nest' <- expandWith outer_nest (loopNestingWidth outer_nest : inner_widths)
  -- 'tails' pairs the k-th inner nesting with the widths from position
  -- k onwards; 'zipWithM' drops the trailing empty tail.
  inner_nests' <- zipWithM expandWith inner_nests (tails inner_widths)
  pure (outer_nest', inner_nests')
  where
    inner_widths = map loopNestingWidth inner_nests

    -- Extend one nesting's pattern.  The rebuilt pattern places all
    -- previous elements plus the new ones in the second field of
    -- 'Pattern', leaving the first field empty.
    expandWith nest dims = do
      extra <- mapM (expandPatElemWith dims) pes
      let existing = patternElements (loopNestingPattern nest)
      pure nest { loopNestingPattern = Pattern [] (existing <> extra) }

    -- Fresh name (same base string) and an array-of-@dims@ type.
    expandPatElemWith dims pe = do
      fresh <- newVName (baseString (patElemName pe))
      pure pe { patElemName = fresh
              , patElemAttr = patElemType pe `arrayOfShape` Shape dims
              }
-- | Commit to, or back out of, a kernel-construction attempt.
--
-- When the attempt produced nothing ('Nothing'), the original statement
-- is certified with @cs@ and added to the original accumulator; the
-- post-statements and the alternative accumulator are ignored.
--
-- When it produced statements ('Just'), the given post-statements are
-- emitted first, then the produced statements (each certified with
-- @cs@), and the alternative accumulator is returned.
kernelOrNot :: (MonadFreshNames m, DistLore lore) =>
               Certificates -> Stm SOACS -> DistAcc lore
            -> PostStms lore -> DistAcc lore -> Maybe (Stms lore)
            -> DistNestT lore m (DistAcc lore)
kernelOrNot cs bnd acc kernels acc' attempt =
  case attempt of
    Nothing ->
      addStmToAcc (certify cs bnd) acc
    Just bnds -> do
      addPostStms kernels
      postStm $ certify cs <$> bnds
      pure acc'
-- | Distribute a map loop.
--
-- The statements of the map's lambda body are distributed inside a
-- 'mapNesting' for this loop, starting from an accumulator that has no
-- statements and whose targets are the incoming targets with this
-- map's pattern and body result pushed as the new inner target.  The
-- accumulator that comes out is passed through 'leavingNesting' and
-- then 'distribute' once more.
distributeMap :: (MonadFreshNames m, DistLore lore) =>
                 MapLoop -> DistAcc lore
              -> DistNestT lore m (DistAcc lore)
distributeMap (MapLoop pat cs w lam arrs) acc = do
  nested <-
    mapNesting pat cs w lam arrs $
      distributeMapBodyStms inner_acc body_stms >>= distribute
  leavingNesting nested >>= distribute
  where
    body = lambdaBody lam
    body_stms = bodyStms body

    inner_acc =
      DistAcc { distTargets =
                  pushInnerTarget (pat, bodyResult body) (distTargets acc)
              , distStms = mempty
              }