{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns -Wno-incomplete-patterns -Wno-incomplete-uni-patterns -Wno-incomplete-record-updates #-}
module Futhark.Pass.ExtractKernels.DistributeNests
( MapLoop (..),
mapLoopStm,
bodyContainsParallelism,
lambdaContainsParallelism,
determineReduceOp,
histKernel,
DistEnv (..),
DistAcc (..),
runDistNestT,
DistNestT,
liftInner,
distributeMap,
distribute,
distributeSingleStm,
distributeMapBodyStms,
addStmsToAcc,
addStmToAcc,
permutationAndMissing,
addPostStms,
postStm,
inNesting,
)
where
import Control.Arrow (first)
import Control.Monad.Identity
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.Trans.Maybe
import Control.Monad.Writer.Strict
import Data.Function ((&))
import Data.List (find, partition, tails)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map qualified as M
import Data.Maybe
import Futhark.IR
import Futhark.IR.GPU.Op (SegVirt (..))
import Futhark.IR.SOACS (SOACS)
import Futhark.IR.SOACS qualified as SOACS
import Futhark.IR.SOACS.SOAC hiding (HistOp, histDest)
import Futhark.IR.SOACS.Simplify (simpleSOACS, simplifyStms)
import Futhark.IR.SegOp
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.Interchange
import Futhark.Tools
import Futhark.Transform.CopyPropagate
import Futhark.Transform.FirstOrderTransform qualified as FOT
import Futhark.Transform.Rename
import Futhark.Util
import Futhark.Util.Log
-- | Reinterpret a scope of any representation that shares its scope
-- information with 'SOACS' as a 'SOACS' scope.  This is 'castScope' with
-- the target representation fixed.
scopeForSOACs :: SameScope rep SOACS => Scope rep -> Scope SOACS
scopeForSOACs scope = castScope scope
-- | The components of a map statement, in constructor order: the result
-- pattern, the auxiliary statement information, the size passed to
-- 'Screma', the mapped lambda, and the input arrays.  'mapLoopStm' turns
-- these back into a statement.
data MapLoop
  = MapLoop
      (Pat Type)
      (StmAux ())
      SubExp
      (Lambda SOACS)
      [VName]
-- | Rebuild the SOACS statement described by a 'MapLoop': a 'Screma' whose
-- form is the map-only form ('mapSOAC') of the stored lambda, bound to the
-- stored pattern with the stored auxiliary information.
mapLoopStm :: MapLoop -> Stm SOACS
mapLoopStm (MapLoop res_pat stm_aux width fun inputs) =
  Let res_pat stm_aux (Op (Screma width inputs (mapSOAC fun)))
-- | The read-only environment carried by 'DistNestT'.
data DistEnv rep m = DistEnv
  { -- | The nestings currently being processed.  'mapNesting' pushes a new
    -- innermost nesting, 'withStm' records let-bound names in the innermost
    -- nesting, and 'inNesting' replaces the whole value.
    distNest :: Nestings,
    -- | The scope reported by 'askScope' inside 'DistNestT'.
    distScope :: Scope rep,
    -- | Callback converting a sequence of SOACS statements to the target
    -- representation.  NOTE(review): not invoked in the part of the module
    -- reviewed here; presumably used for statements outside any map nest —
    -- confirm against the callers.
    distOnTopLevelStms :: Stms SOACS -> DistNestT rep m (Stms rep),
    -- | Callback receiving a 'MapLoop' and the current accumulator and
    -- producing the updated accumulator.  NOTE(review): not invoked in the
    -- part of the module reviewed here — confirm intended use.
    distOnInnerMap ::
      MapLoop ->
      DistAcc rep ->
      DistNestT rep m (DistAcc rep),
    -- | Converts a single SOACS statement into statements of the target
    -- representation; used by 'addStmToAcc'.
    distOnSOACSStms :: Stm SOACS -> Builder rep (Stms rep),
    -- | Converts a SOACS lambda into a lambda of the target representation;
    -- used by 'soacsLambda'.
    distOnSOACSLambda :: Lambda SOACS -> Builder rep (Lambda rep),
    -- | NOTE(review): presumably selects the level of generated segmented
    -- operations; not used in the part of the module reviewed here —
    -- confirm against 'MkSegLevel'.
    distSegLevel :: MkSegLevel rep m
  }
-- | The accumulator threaded through distribution.
data DistAcc rep = DistAcc
  { -- | The remaining targets (pattern/result pairs); 'leavingNesting' pops
    -- the innermost one and 'runDistNestT' consumes the outermost one.
    distTargets :: Targets,
    -- | Statements accumulated so far.  'addStmsToAcc' and 'addStmToAcc'
    -- place new statements in front of the existing ones.
    distStms :: Stms rep
  }
-- | The writer output of 'DistNestT'.
data DistRes rep = DistRes
  { -- | Statements emitted through 'addPostStms' / 'postStm'.
    accPostStms :: PostStms rep,
    -- | Log messages emitted through 'addLog'.
    accLog :: Log
  }
-- | Combines two results field by field, using each field's own '<>'.
instance Semigroup (DistRes rep) where
  DistRes stms_l log_l <> DistRes stms_r log_r =
    DistRes
      { accPostStms = stms_l <> stms_r,
        accLog = log_l <> log_r
      }
-- | The empty result: no post-statements and an empty log.
instance Monoid (DistRes rep) where
  mempty = DistRes {accPostStms = mempty, accLog = mempty}
-- | A wrapper around 'Stms' whose 'Semigroup' instance concatenates in the
-- opposite order to the one on 'Stms' itself.
newtype PostStms rep = PostStms
  { unPostStms :: Stms rep
  }
-- | Note the flipped order: the statements of the right operand come
-- /before/ those of the left operand in the combined value.
instance Semigroup (PostStms rep) where
  PostStms lhs_stms <> PostStms rhs_stms = PostStms (rhs_stms <> lhs_stms)
-- | The empty statement sequence.
instance Monoid (PostStms rep) where
  mempty = PostStms {unPostStms = mempty}
-- | The scope introduced by the pattern of the outermost target of the
-- accumulator.
typeEnvFromDistAcc :: DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc acc = scopeOfPat outer_pat
  where
    (outer_pat, _) = outerTarget (distTargets acc)
-- | Put the given statements in front of the statements already held by
-- the accumulator.
addStmsToAcc :: Stms rep -> DistAcc rep -> DistAcc rep
addStmsToAcc new_stms acc@DistAcc {distStms = old_stms} =
  acc {distStms = new_stms <> old_stms}
-- | Convert one SOACS statement with the environment's 'distOnSOACSStms'
-- callback and put the resulting statements in front of those already in
-- the accumulator.  Only the value returned by the builder is kept; any
-- statements the builder itself emitted are discarded.
addStmToAcc ::
  (MonadFreshNames m, DistRep rep) =>
  Stm SOACS ->
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
addStmToAcc stm acc = do
  convert <- asks distOnSOACSStms
  (converted, _) <- runBuilder (convert stm)
  pure (addStmsToAcc converted acc)
-- | Convert a SOACS lambda with the environment's 'distOnSOACSLambda'
-- callback.  Statements emitted by the builder while doing so are
-- discarded; only the resulting lambda is returned.
soacsLambda ::
  (MonadFreshNames m, DistRep rep) =>
  Lambda SOACS ->
  DistNestT rep m (Lambda rep)
soacsLambda lam = do
  convert <- asks distOnSOACSLambda
  (lam', _) <- runBuilder (convert lam)
  pure lam'
-- | The distribution monad transformer: a reader of 'DistEnv' layered over
-- a writer of 'DistRes', on top of the underlying monad @m@.
newtype DistNestT rep m a
  = DistNestT (ReaderT (DistEnv rep m) (WriterT (DistRes rep) m) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (DistEnv rep m),
      MonadWriter (DistRes rep)
    )
-- | Run an action of the underlying monad inside 'DistNestT'.  Before the
-- action runs, the underlying monad's scope is extended (via 'localScope')
-- with every binding that is in the 'DistNestT' scope but not already in
-- the underlying monad's own scope.
liftInner :: (LocalScope rep m, DistRep rep) => m a -> DistNestT rep m a
liftInner action = do
  outer_scope <- askScope
  DistNestT . lift . lift $ do
    inner_scope <- askScope
    let only_outer = M.difference outer_scope inner_scope
    localScope only_outer action
-- | The name source is delegated to the monad underneath the reader layer.
instance MonadFreshNames m => MonadFreshNames (DistNestT rep m) where
  getNameSource = DistNestT (lift getNameSource)
  putNameSource src = DistNestT (lift (putNameSource src))
-- | The scope is the 'distScope' field of the environment.
instance (Monad m, ASTRep rep) => HasScope rep (DistNestT rep m) where
  askScope = distScope <$> ask
-- | Locally extends 'distScope'.  The new bindings are the left operand of
-- '<>', in front of the existing scope.
instance (Monad m, ASTRep rep) => LocalScope rep (DistNestT rep m) where
  localScope extra = local extend
    where
      extend env = env {distScope = extra <> distScope env}
-- | Log messages are written to the 'accLog' field of the writer output.
instance Monad m => MonadLogger (DistNestT rep m) where
  addLog msgs = tell DistRes {accPostStms = mempty, accLog = msgs}
-- | Run a distribution action in the given environment.
--
-- The log accumulated by the action is forwarded to the underlying monad
-- with 'addLog'.  The returned statements are the accumulated
-- post-statements followed by one statement per pattern element of the
-- outermost remaining target:
--
-- * if the corresponding result is a variable naming a parameter of the
--   outermost loop nesting, a 'Copy' of the array bound to that parameter;
--
-- * otherwise a 'Replicate' of the result to the width of the outermost
--   loop nesting.
--
-- Each such statement carries the certificates of its result.
runDistNestT ::
  (MonadLogger m, DistRep rep) =>
  DistEnv rep m ->
  DistNestT rep m (DistAcc rep) ->
  m (Stms rep)
runDistNestT env (DistNestT m) = do
  (acc, res) <- runWriterT (runReaderT m env)
  addLog (accLog res)
  let final_target = outerTarget (distTargets acc)
  pure (unPostStms (accPostStms res) <> identityStms final_target)
  where
    -- The outermost nesting is the head of the list component when that
    -- list is non-empty, and the first component otherwise.
    outermost_nesting =
      case distNest env of
        (nest, []) -> nest
        (_, nest : _) -> nest
    outermost = nestingLoop outermost_nesting

    -- Association list from parameter name to the array it is drawn from.
    params_to_arrs =
      [ (paramName param, arr)
        | (param, arr) <- loopNestingParamsAndArrs outermost
      ]

    identityStms (rem_pat, results) =
      stmsFromList (zipWith identityStm (patElems rem_pat) results)

    identityStm pe (SubExpRes cs se) =
      certify cs . Let (Pat [pe]) (defAux ()) . BasicOp $
        case se of
          Var v
            | Just arr <- lookup v params_to_arrs -> Copy arr
          _ -> Replicate (Shape [loopNestingWidth outermost]) se
-- | Emit the given post-statements to the writer output, leaving the log
-- untouched.
addPostStms :: Monad m => PostStms rep -> DistNestT rep m ()
addPostStms ks = tell DistRes {accPostStms = ks, accLog = mempty}
-- | Emit the given statements as post-statements; shorthand for wrapping
-- them in 'PostStms' and calling 'addPostStms'.
postStm :: Monad m => Stms rep -> DistNestT rep m ()
postStm = addPostStms . PostStms
-- | Run an action in an environment where the given statement is in
-- effect: the names it binds are added to the scope, and are recorded as
-- let-bound in the innermost nesting via 'letBindInInnerNesting'.
withStm ::
  (Monad m, DistRep rep) =>
  Stm SOACS ->
  DistNestT rep m a ->
  DistNestT rep m a
withStm stm = local bindStm
  where
    bound_names = namesFromList (patNames (stmPat stm))
    bindStm env =
      env
        { distScope = castScope (scopeOf stm) <> distScope env,
          distNest = letBindInInnerNesting bound_names (distNest env)
        }
-- | Finish the innermost nesting: pop the innermost target from the
-- accumulator and replace the accumulated statements with statements that
-- produce that target's pattern.
--
-- * If the accumulator holds no statements, every pattern element is bound
--   either to a 'Copy' of the input array (when its result is a variable
--   naming a parameter of the innermost loop nesting) or to a 'Replicate'
--   of the result to the nesting width.
--
-- * Otherwise the leftover statements become the body of a map lambda over
--   only those nesting parameters that occur free in that body, and the
--   resulting 'Screma' is passed to 'FOT.transformSOAC'.
--
-- Calls 'error' if there is no inner target to pop.
leavingNesting ::
  (MonadFreshNames m, DistRep rep) =>
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
leavingNesting acc =
  case popInnerTarget (distTargets acc) of
    Nothing ->
      error "The kernel targets list is unexpectedly small"
    Just ((pat, res), outer_targets)
      | null leftover -> do
          inner <- asks (nestingLoop . fst . distNest)
          let width = loopNestingWidth inner
              aux = loopNestingAux inner
              inputs = loopNestingParamsAndArrs inner
              -- The array bound to the parameter named v, if any.
              arrayOfParam v = snd <$> find ((== v) . paramName . fst) inputs
              remnantStm pe (SubExpRes cs se) =
                certify cs . Let (Pat [pe]) aux . BasicOp $
                  case se of
                    Var v
                      | Just arr <- arrayOfParam v -> Copy arr
                    _ -> Replicate (Shape [width]) se
              remnants = stmsFromList (zipWith remnantStm (patElems pat) res)
          pure acc {distTargets = outer_targets, distStms = remnants}
      | otherwise -> do
          inner <- asks (nestingLoop . fst . distNest)
          let MapNesting _ aux width params_and_arrs = inner
              body = Body () leftover res
              free_in_body = freeIn body
              -- Keep only the inputs whose parameter the body mentions.
              (used_params, used_arrs) =
                unzip
                  [ pa
                    | pa@(param, _) <- params_and_arrs,
                      paramName param `nameIn` free_in_body
                  ]
              seq_lam =
                Lambda
                  { lambdaParams = used_params,
                    lambdaBody = body,
                    lambdaReturnType = map rowType (patTypes pat)
                  }
          seq_stms <-
            runBuilder_ . auxing aux . FOT.transformSOAC pat $
              Screma width used_arrs (mapSOAC seq_lam)
          pure acc {distTargets = outer_targets, distStms = seq_stms}
  where
    leftover = distStms acc
-- | Run an action inside a new innermost map nesting built from the given
-- pattern, auxiliary information, width, lambda and input arrays, and
-- afterwards close that nesting with 'leavingNesting'.  Both the action
-- and 'leavingNesting' run in the extended environment, in which the
-- lambda's scope has also been added to 'distScope'.
mapNesting ::
  (MonadFreshNames m, DistRep rep) =>
  Pat Type ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  DistNestT rep m (DistAcc rep) ->
  DistNestT rep m (DistAcc rep)
mapNesting pat aux w lam arrs m =
  local enterMap (m >>= leavingNesting)
  where
    -- Lambda parameters paired positionally with the input arrays.
    loop = MapNesting pat aux w (zip (lambdaParams lam) arrs)
    enterMap env =
      env
        { distNest = pushInnerNesting (Nesting mempty loop) (distNest env),
          distScope = castScope (scopeOf lam) <> distScope env
        }
-- | Run an action with the nestings of the environment replaced by those
-- of the given 'KernelNest', and with the scopes of all its loop nestings
-- added to 'distScope'.  The last loop nesting of the kernel nest becomes
-- the first component of 'distNest'; the remaining ones follow in their
-- original order, starting with the outer one.  Every created 'Nesting'
-- starts with an empty set of let-bound names.
inNesting ::
  (Monad m, DistRep rep) =>
  KernelNest ->
  DistNestT rep m a ->
  DistNestT rep m a
inNesting (outer, nests) = local enter
  where
    asNesting = Nesting mempty
    (innermost, enclosing) =
      case reverse nests of
        [] -> (asNesting outer, [])
        last_nest : rev_init ->
          (asNesting last_nest, map asNesting (outer : reverse rev_init))
    enter env =
      env
        { distNest = (innermost, enclosing),
          distScope =
            mconcat (map scopeOfLoopNesting (outer : nests)) <> distScope env
        }
-- | Does any statement directly in this body count as parallel?  A
-- statement counts when it does not carry the @sequential@ attribute and
-- its expression is either an 'Op', a for-loop whose body contains
-- parallelism, or a 'WithAcc' whose lambda body contains parallelism.
-- While-loops, 'Match', 'Apply' and 'BasicOp' never count.
bodyContainsParallelism :: Body SOACS -> Bool
bodyContainsParallelism body = any isParallelStm (bodyStms body)
  where
    isParallelStm stm =
      isMap (stmExp stm)
        && not ("sequential" `inAttrs` stmAuxAttrs (stmAux stm))
    isMap = \case
      Op {} -> True
      DoLoop _ ForLoop {} loop_body -> bodyContainsParallelism loop_body
      DoLoop _ WhileLoop {} _ -> False
      WithAcc _ lam -> bodyContainsParallelism (lambdaBody lam)
      BasicOp {} -> False
      Apply {} -> False
      Match {} -> False
-- | 'bodyContainsParallelism' applied to the body of a lambda.
lambdaContainsParallelism :: Lambda SOACS -> Bool
lambdaContainsParallelism lam = bodyContainsParallelism (lambdaBody lam)
-- | Process the statements of a map body against the given accumulator
-- and finish by calling 'distribute' on the resulting accumulator.
--
-- The statements are handled from the last one to the first: for a
-- list @stm : stms@, the tail @stms@ is processed first and the
-- accumulator it produces is then passed to 'maybeDistributeStm' for
-- @stm@.
distributeMapBodyStms ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  Stms SOACS ->
  DistNestT rep m (DistAcc rep)
distributeMapBodyStms orig_acc = distribute <=< onStms orig_acc . stmsToList
  where
    -- No statements left: the accumulator is the result.
    onStms acc [] = pure acc
    -- A Stream is never handed to 'maybeDistributeStm' directly.  It is
    -- first rewritten with 'sequentialStreamWholeArray', the resulting
    -- statements are copy-propagated, and the certificates of the
    -- original statement are attached to each of them.  The rewritten
    -- statements then replace the Stream at the front of the work list.
    -- Note that the attributes of the original statement are dropped.
    onStms acc (Let pat (StmAux cs _ _) (Op (Stream w arrs accs lam)) : stms) = do
      types <- asksScope scopeForSOACs
      stream_stms <-
        snd <$> runBuilderT (sequentialStreamWholeArray pat w accs lam arrs) types
      stream_stms' <-
        runReaderT (copyPropagateInStms simpleSOACS types stream_stms) types
      onStms acc $ stmsToList (fmap (certify cs) stream_stms') ++ stms
    -- Any other statement: handle the remaining statements first, then
    -- this one, all inside 'withStm stm'.
    onStms acc (stm : stms) =
      withStm stm $ maybeDistributeStm stm =<< onStms acc stms
-- | Handle an inner map loop by delegating to the 'distOnInnerMap'
-- callback stored in the 'DistEnv' of the current computation.
onInnerMap :: Monad m => MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
onInnerMap loop acc = do
  f <- asks distOnInnerMap
  f loop acc
-- | Transform the given SOACS statements with the 'distOnTopLevelStms'
-- callback stored in the 'DistEnv', and pass the statements it
-- returns to 'postStm'.
onTopLevelStms :: Monad m => Stms SOACS -> DistNestT rep m ()
onTopLevelStms stms = do
  f <- asks distOnTopLevelStms
  postStm =<< f stms
-- | Decide what to do with a single statement given the current
-- accumulator.  The equations are tried in order and dispatch on the
-- attributes and the expression of the statement.
maybeDistributeStm ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  Stm SOACS ->
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
-- A statement carrying the "sequential" attribute is passed straight
-- to 'addStmToAcc', whatever its expression is.
maybeDistributeStm stm acc
  | "sequential" `inAttrs` stmAuxAttrs (stmAux stm) =
      addStmToAcc stm acc
-- A SOAC carrying the "sequential_outer" attribute is rewritten with
-- 'FOT.transformSOAC'.  The certificates of the original statement
-- are attached to every resulting statement, and those statements are
-- then processed by 'distributeMapBodyStms'.
maybeDistributeStm (Let pat aux (Op soac)) acc
  | "sequential_outer" `inAttrs` stmAuxAttrs aux =
      distributeMapBodyStms acc . fmap (certify (stmAuxCerts aux))
        =<< runBuilder_ (FOT.transformSOAC pat soac)
-- A Screma that is a plain map.  'distributeIfPossible' is tried on
-- the accumulator first: if it yields nothing, the map is added to the
-- accumulator as an ordinary statement; otherwise the map is handed
-- to 'onInnerMap' with the new accumulator, followed by 'distribute'.
maybeDistributeStm stm@(Let pat _ (Op (Screma w arrs form))) acc
  | Just lam <- isMapSOAC form =
      distributeIfPossible acc >>= \case
        Nothing -> addStmToAcc stm acc
        Just acc' ->
          distribute =<< onInnerMap (MapLoop pat (stmAux stm) w lam arrs) acc'
-- A for-loop whose body contains parallelism, and whose pattern names
-- do not occur free in the pattern's own types.  If the statement can
-- be distributed on its own, the loop is interchanged with the
-- surrounding nest via 'interchangeLoops'; otherwise it is added to
-- the accumulator unchanged.
maybeDistributeStm stm@(Let pat aux (DoLoop merge form@ForLoop {} body)) acc
  | all (`notNameIn` freeIn (patTypes pat)) (patNames pat),
    bodyContainsParallelism body =
      distributeSingleStm acc stm >>= \case
        Just (kernels, res, nest, acc')
          | -- Do not interchange if the loop form or the statement's
            -- auxiliary information mentions a name bound by the nest.
            not $
              (freeIn form <> freeIn aux)
                `namesIntersect` boundInKernelNest nest,
            Just (perm, pat_unused) <- permutationAndMissing pat res ->
              localScope (typeEnvFromDistAcc acc') $ do
                addPostStms kernels
                -- The pattern elements reported as missing by
                -- 'permutationAndMissing' are added to the nest.
                nest' <- expandKernelNest pat_unused nest
                types <- asksScope scopeForSOACs
                -- The interchanged statements are simplified before
                -- being processed as top-level statements.
                stms <-
                  (`runReaderT` types) $
                    simplifyStms
                      =<< interchangeLoops nest' (SeqLoop perm pat merge form body)
                onTopLevelStms stms
                pure acc'
        _ ->
          addStmToAcc stm acc
-- A Match where no pattern name occurs free in the pattern itself,
-- and where either some branch body contains parallelism or some
-- return type is not primitive.  If the statement can be distributed
-- on its own, the branch is interchanged with the surrounding nest via
-- 'interchangeBranch'; otherwise it is added to the accumulator
-- unchanged.
maybeDistributeStm stm@(Let pat _ (Match cond cases defbody ret)) acc
  | all (`notNameIn` freeIn pat) (patNames pat),
    any bodyContainsParallelism (defbody : map caseBody cases)
      || not (all primType (matchReturns ret)) =
      distributeSingleStm acc stm >>= \case
        Just (kernels, res, nest, acc')
          | -- Do not interchange if the scrutinees or the return
            -- annotation mention a name bound by the nest.
            not $
              (freeIn cond <> freeIn ret)
                `namesIntersect` boundInKernelNest nest,
            Just (perm, pat_unused) <- permutationAndMissing pat res ->
              localScope (typeEnvFromDistAcc acc') $ do
                -- The pattern elements reported as missing by
                -- 'permutationAndMissing' are added to the nest.
                nest' <- expandKernelNest pat_unused nest
                addPostStms kernels
                types <- asksScope scopeForSOACs
                let branch = Branch perm pat cond cases defbody ret
                -- The interchanged statement is simplified before
                -- being processed as a top-level statement.
                stms <-
                  (`runReaderT` types) $
                    simplifyStms . oneStm =<< interchangeBranch nest' branch
                onTopLevelStms stms
                pure acc'
        _ ->
          addStmToAcc stm acc
-- A WithAcc whose lambda contains parallelism.  If the statement can
-- be distributed on its own, it is interchanged with the surrounding
-- nest via 'interchangeWithAcc'; otherwise it is added to the
-- accumulator unchanged.
maybeDistributeStm stm@(Let pat _ (WithAcc inputs lam)) acc
  | lambdaContainsParallelism lam =
      distributeSingleStm acc stm >>= \case
        Just (kernels, res, nest, acc')
          | -- Do not interchange if the lambda's return types, after
            -- the first 'num_accs' of them, mention a name bound by
            -- the nest.
            not $
              freeIn (drop num_accs (lambdaReturnType lam))
                `namesIntersect` boundInKernelNest nest,
            Just (perm, pat_unused) <- permutationAndMissing pat res ->
              localScope (typeEnvFromDistAcc acc') $ do
                -- The pattern elements reported as missing by
                -- 'permutationAndMissing' are added to the nest.
                nest' <- expandKernelNest pat_unused nest
                types <- asksScope scopeForSOACs
                addPostStms kernels
                let withacc = WithAccStm perm pat inputs lam
                -- The interchanged statement is simplified before
                -- being processed as a top-level statement.
                stms <-
                  (`runReaderT` types) $
                    simplifyStms . oneStm =<< interchangeWithAcc nest' withacc
                onTopLevelStms stms
                pure acc'
        _ ->
          addStmToAcc stm acc
  where
    -- One leading return type is skipped per WithAcc input.
    num_accs = length inputs
-- A Screma that is a reduction with exactly one operator, and for
-- which 'irwim' produces a rewrite.  The rewrite is run in a builder
-- under the auxiliary information of the original statement (via
-- 'auxing'), and the statements it produces are processed by
-- 'distributeMapBodyStms'.
maybeDistributeStm (Let pat aux (Op (Screma w arrs form))) acc
  | Just [Reduce comm lam nes] <- isReduceSOAC form,
    Just m <- irwim pat w comm lam $ zip nes arrs = do
      types <- asksScope scopeForSOACs
      (_, stms) <- runBuilderT (auxing aux m) types
      distributeMapBodyStms acc stms
-- A Scatter.  If the statement can be distributed on its own and its
-- pattern can be matched against the result by
-- 'permutationAndMissing', the statements built by
-- 'segmentedScatterKernel' are posted; otherwise the Scatter is added
-- to the accumulator unchanged.
maybeDistributeStm stm@(Let pat (StmAux cs _ _) (Op (Scatter w ivs lam as))) acc =
  distributeSingleStm acc stm >>= \case
    Just (kernels, res, nest, acc')
      | Just (perm, pat_unused) <- permutationAndMissing pat res ->
          localScope (typeEnvFromDistAcc acc') $ do
            -- The pattern elements reported as missing by
            -- 'permutationAndMissing' are added to the nest.
            nest' <- expandKernelNest pat_unused nest
            lam' <- soacsLambda lam
            addPostStms kernels
            postStm =<< segmentedScatterKernel nest' perm pat cs w lam' ivs as
            pure acc'
    _ ->
      addStmToAcc stm acc
-- A Hist.  If the statement can be distributed on its own and its
-- pattern can be matched against the result by
-- 'permutationAndMissing', the statements built by
-- 'segmentedHistKernel' are posted; otherwise the Hist is added to
-- the accumulator unchanged.
maybeDistributeStm stm@(Let pat (StmAux cs _ _) (Op (Hist w as ops lam))) acc =
  distributeSingleStm acc stm >>= \case
    Just (kernels, res, nest, acc')
      | Just (perm, pat_unused) <- permutationAndMissing pat res ->
          localScope (typeEnvFromDistAcc acc') $ do
            lam' <- soacsLambda lam
            -- The pattern elements reported as missing by
            -- 'permutationAndMissing' are added to the nest.
            nest' <- expandKernelNest pat_unused nest
            addPostStms kernels
            postStm =<< segmentedHistKernel nest' perm cs w ops lam' as
            pure acc'
    _ ->
      addStmToAcc stm acc
-- An Index with a single-element pattern, where the slice has at
-- least one dimension and the bound name is one of the results of the
-- innermost target.  If the statement can be distributed on its own,
-- the statements built by 'segmentedGatherKernel' are posted;
-- otherwise the Index is added to the accumulator unchanged.
maybeDistributeStm stm@(Let (Pat [pe]) aux (BasicOp (Index arr slice))) acc
  | not $ null $ sliceDims slice,
    Var (patElemName pe)
      `elem` map resSubExp (snd (innerTarget (distTargets acc))) =
      distributeSingleStm acc stm >>= \case
        Just (kernels, _res, nest, acc') ->
          localScope (typeEnvFromDistAcc acc') $ do
            addPostStms kernels
            postStm =<< segmentedGatherKernel nest (stmAuxCerts aux) arr slice
            pure acc'
        _ ->
          addStmToAcc stm acc
-- A Screma that is a scanomap.  Its scan operators are combined into
-- one with 'singleScan'.  If the statement can be distributed on its
-- own, 'segmentedScanomapKernel' is attempted and its (optional)
-- result is passed to 'kernelOrNot'; otherwise the statement is added
-- to the accumulator unchanged.
maybeDistributeStm stm@(Let pat (StmAux cs _ _) (Op (Screma w arrs form))) acc
  | Just (scans, map_lam) <- isScanomapSOAC form,
    Scan lam nes <- singleScan scans =
      distributeSingleStm acc stm >>= \case
        Just (kernels, res, nest, acc')
          | Just (perm, pat_unused) <- permutationAndMissing pat res ->
              localScope (typeEnvFromDistAcc acc') $ do
                -- The pattern elements reported as missing by
                -- 'permutationAndMissing' are added to the nest.
                nest' <- expandKernelNest pat_unused nest
                map_lam' <- soacsLambda map_lam
                -- Only the map lambda is converted with 'soacsLambda';
                -- the scan lambda is passed on as it is.
                localScope (typeEnvFromDistAcc acc') $
                  segmentedScanomapKernel nest' perm cs w lam map_lam' nes arrs
                    >>= kernelOrNot mempty stm acc kernels acc'
        _ ->
          addStmToAcc stm acc
-- A Screma that is a redomap whose map lambda contains parallelism.
-- It is split into a map statement and a reduce statement with
-- 'redomapToMapAndReduce'.  The map statement is given the auxiliary
-- information of the original statement, and both statements are
-- processed by 'distributeMapBodyStms'.
maybeDistributeStm (Let pat aux (Op (Screma w arrs form))) acc
  | Just (reds, map_lam) <- isRedomapSOAC form,
    lambdaContainsParallelism map_lam = do
      (mapstm, redstm) <-
        redomapToMapAndReduce pat (w, reds, map_lam, arrs)
      distributeMapBodyStms acc $
        oneStm mapstm {stmAux = aux} <> oneStm redstm
-- A Screma that is a redomap.  Its reduction operators are combined
-- into one with 'singleReduce'.  If the statement can be distributed
-- on its own, 'regularSegmentedRedomapKernel' is attempted and its
-- (optional) result is passed to 'kernelOrNot'; otherwise the
-- statement is added to the accumulator unchanged.
maybeDistributeStm stm@(Let pat (StmAux cs _ _) (Op (Screma w arrs form))) acc
  | Just (reds, map_lam) <- isRedomapSOAC form,
    Reduce comm lam nes <- singleReduce reds =
      distributeSingleStm acc stm >>= \case
        Just (kernels, res, nest, acc')
          | Just (perm, pat_unused) <- permutationAndMissing pat res ->
              localScope (typeEnvFromDistAcc acc') $ do
                -- The pattern elements reported as missing by
                -- 'permutationAndMissing' are added to the nest.
                nest' <- expandKernelNest pat_unused nest
                lam' <- soacsLambda lam
                map_lam' <- soacsLambda map_lam

                -- Treat the operator as commutative if
                -- 'commutativeLambda' says so, even when the SOAC does
                -- not declare it.
                let comm'
                      | commutativeLambda lam = Commutative
                      | otherwise = comm

                regularSegmentedRedomapKernel nest' perm cs w comm' lam' map_lam' nes arrs
                  >>= kernelOrNot mempty stm acc kernels acc'
        _ ->
          addStmToAcc stm acc
-- Any Screma not handled by an earlier equation.  It is split up with
-- 'dissectScrema'.  The certificates of the original statement are
-- attached to each resulting statement, and those statements are
-- processed by 'distributeMapBodyStms'.
maybeDistributeStm (Let pat (StmAux cs _ _) (Op (Screma w arrs form))) acc = do
  scope <- asksScope scopeForSOACs
  distributeMapBodyStms acc . fmap (certify cs) . snd
    =<< runBuilderT (dissectScrema pat w form arrs) scope
-- A Replicate with a non-empty shape and a single-element pattern.
-- It is rewritten as a map of width @d@ over no input arrays.  The
-- lambda of that map takes no parameters and replicates @v@ with the
-- remaining dimensions @ds@.  The rewritten statement is then given
-- to 'maybeDistributeStm' again.
maybeDistributeStm (Let pat aux (BasicOp (Replicate (Shape (d : ds)) v))) acc
  | [t] <- patTypes pat = do
      tmp <- newVName "tmp"
      let rowt = rowType t
          -- The replacement statement: same pattern and auxiliary
          -- information, but a map SOAC as its expression.
          newstm = Let pat aux $ Op $ Screma d [] $ mapSOAC lam
          -- The sole statement of the lambda body; it binds the row.
          tmpstm =
            Let (Pat [PatElem tmp rowt]) aux $
              BasicOp $
                Replicate (Shape ds) v
          lam =
            Lambda
              { lambdaReturnType = [rowt],
                lambdaParams = [],
                lambdaBody = mkBody (oneStm tmpstm) [varRes tmp]
              }
      maybeDistributeStm newstm acc
-- A 'Copy' is treated as a unary statement on @stm_arr@: when
-- 'distributeSingleUnaryStm' succeeds, the statement emitted is a
-- single 'Copy' of the array it supplies, bound to the pattern it
-- supplies, with the auxiliary information of the original statement.
maybeDistributeStm stm@(Let _ aux (BasicOp (Copy stm_arr))) acc =
  distributeSingleUnaryStm acc stm stm_arr copyWhole
  where
    -- The kernel nest itself is not needed to build the copy.
    copyWhole _nest whole_pat whole_arr =
      pure . oneStm $ Let whole_pat aux (BasicOp (Copy whole_arr))
-- An 'Opaque' applied to a variable whose single result is not of
-- primitive type is handled like a unary statement on @stm_arr@.
-- Note that the statement emitted on success is a 'Copy' of the array
-- supplied by 'distributeSingleUnaryStm', not another 'Opaque'.
-- If the guard fails, matching falls through to the later equations.
maybeDistributeStm stm@(Let (Pat [pe]) aux (BasicOp (Opaque _ (Var stm_arr)))) acc
  | not (primType (typeOf pe)) =
      distributeSingleUnaryStm acc stm stm_arr copyWhole
  where
    -- The kernel nest itself is not needed to build the copy.
    copyWhole _nest whole_pat whole_arr =
      pure . oneStm $ Let whole_pat aux (BasicOp (Copy whole_arr))
-- A 'Rearrange' is treated as a unary statement on @stm_arr@.  On
-- success two statements are emitted: a 'Copy' of the supplied array
-- into a fresh name (named after the array, with the type looked up
-- for it), followed by a 'Rearrange' of that copy.  The permutation
-- used keeps the leading @depth@ dimensions in place, where @depth@ is
-- the number of nestings in the kernel nest, and applies the original
-- permutation shifted by @depth@ to the rest.
maybeDistributeStm stm@(Let _ aux (BasicOp (Rearrange perm stm_arr))) acc =
  distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr -> do
    arr_copy <- newVName (baseString arr)
    arr_t <- lookupType arr
    let -- One outermost nesting plus the inner ones.
        depth = 1 + length (snd nest)
        full_perm = take depth [0 ..] ++ [i + depth | i <- perm]
        copy_stm =
          Let (Pat [PatElem arr_copy arr_t]) aux (BasicOp (Copy arr))
        rearrange_stm =
          Let outerpat aux (BasicOp (Rearrange full_perm arr_copy))
    pure (stmsFromList [copy_stm, rearrange_stm])
-- A 'Reshape' is treated as a unary statement on @stm_arr@.  On
-- success a single 'Reshape' of the supplied array is emitted, with
-- the same reshape kind; its target shape is the widths of the kernel
-- nest followed by the original target shape.
maybeDistributeStm stm@(Let _ aux (BasicOp (Reshape k reshape stm_arr))) acc =
  distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr ->
    pure . oneStm . Let outerpat aux . BasicOp $
      Reshape k (Shape (kernelNestWidths nest) <> reshape) arr
-- A 'Rotate' is treated as a unary statement on @stm_arr@.  On
-- success a single 'Rotate' of the supplied array is emitted.  Its
-- rotation amounts are a 64-bit zero for every width of the kernel
-- nest, followed by the original rotation amounts.
maybeDistributeStm stm@(Let _ aux (BasicOp (Rotate rots stm_arr))) acc =
  distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr ->
    let zero_rots = intConst Int64 0 <$ kernelNestWidths nest
     in pure . oneStm $
          Let outerpat aux (BasicOp (Rotate (zero_rots ++ rots) arr))
-- An 'Update' that writes a variable @v@ through a slice with at least
-- one non-fixed dimension.  The statement is first tried with
-- 'distributeSingleStm'.  A segmented update kernel is produced only
-- when that succeeds, the returned result is exactly the names bound
-- by the statement, and 'permutationAndMissing' yields a permutation
-- together with the unused pattern elements.  In every other case the
-- statement is simply added to the accumulator.
maybeDistributeStm stm@(Let pat aux (BasicOp (Update _ arr slice (Var v)))) acc
  | not (null (sliceDims slice)) = do
      attempt <- distributeSingleStm acc stm
      case attempt of
        Just (kernels, res, nest, acc')
          | fmap resSubExp res == fmap Var (patNames (stmPat stm)),
            Just (perm, pat_unused) <- permutationAndMissing pat res -> do
              addPostStms kernels
              localScope (typeEnvFromDistAcc acc') $ do
                -- The nest is expanded with the unused pattern
                -- elements before the kernel is built from it.
                expanded <- expandKernelNest pat_unused nest
                update_stms <-
                  segmentedUpdateKernel
                    expanded
                    perm
                    (stmAuxCerts aux)
                    arr
                    slice
                    v
                postStm update_stms
                pure acc'
        _ -> addStmToAcc stm acc
-- A 'Concat' is first tried with 'distributeSingleStm'.  On success,
-- 'segmentedConcat' attempts to build statements for the whole nest,
-- and its (optional) outcome is passed to 'kernelOrNot' along with the
-- certificates of the statement, both accumulators and the pending
-- post-statements.  On failure the statement is added to the
-- accumulator.
maybeDistributeStm stm@(Let _ aux (BasicOp (Concat d (x :| xs) w))) acc = do
  attempt <- distributeSingleStm acc stm
  case attempt of
    Just (kernels, _, nest, acc') ->
      localScope (typeEnvFromDistAcc acc') $ do
        concat_stms <- segmentedConcat nest
        kernelOrNot (stmAuxCerts aux) stm acc kernels acc' concat_stms
    _ -> addStmToAcc stm acc
  where
    -- Uses 'isSegmentedOp' with the concatenated arrays as the array
    -- arguments.  The statement built is a 'Concat' over the arrays
    -- handed to the callback, along dimension @d@ shifted by the
    -- number of nestings in the kernel nest.
    segmentedConcat nest =
      isSegmentedOp nest [0] mempty mempty [] (x : xs) $
        \pat _ _ _ (x' : xs') ->
          addStm . Let pat aux . BasicOp $
            Concat (d + length (snd nest) + 1) (x' :| xs') w
-- Catch-all: a statement matched by none of the equations above is
-- added to the accumulator as it is.
maybeDistributeStm other_stm dist_acc =
  addStmToAcc other_stm dist_acc
-- | Try to distribute a statement that operates on the single array
-- @stm_arr@.
--
-- The statement is first tried with 'distributeSingleStm'.  The
-- callback @f@ is used only if that succeeds and all of the following
-- hold:
--
-- * the returned result is exactly the names bound by the statement;
--
-- * the outermost nesting has exactly one parameter/array pair;
--
-- * the only name bound in the kernel nest that is free in the
--   statement is that parameter;
--
-- * the array is 'perfectlyMapped' through the nest (see below).
--
-- In that case the pending post-statements are added, @f@ is called
-- with the nest, the pattern of the outermost nesting and the array of
-- the outermost nesting, the statements it returns are posted, and the
-- accumulator from 'distributeSingleStm' is returned.  Otherwise the
-- statement is added to the original accumulator.
distributeSingleUnaryStm ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  Stm SOACS ->
  VName ->
  (KernelNest -> Pat Type -> VName -> DistNestT rep m (Stms rep)) ->
  DistNestT rep m (DistAcc rep)
distributeSingleUnaryStm acc stm stm_arr f = do
  attempt <- distributeSingleStm acc stm
  case attempt of
    Just (kernels, res, nest@(outermost, _), acc')
      | fmap resSubExp res == fmap Var (patNames (stmPat stm)),
        [(arr_p, arr)] <- loopNestingParamsAndArrs outermost,
        namesIntersection (boundInKernelNest nest) (freeIn stm)
          == oneName (paramName arr_p),
        perfectlyMapped arr nest -> do
          addPostStms kernels
          localScope (typeEnvFromDistAcc acc') $ do
            new_stms <- f nest (loopNestingPat outermost) arr
            postStm new_stms
            pure acc'
    _ -> addStmToAcc stm acc
  where
    -- Every nesting must have exactly one parameter/array pair.  The
    -- array of each nesting must be the name expected at that level:
    -- the given array for the outermost nesting, and the parameter of
    -- the enclosing nesting further in.  The parameter of the
    -- innermost nesting must be @stm_arr@.
    perfectlyMapped expected (outer, inner) =
      case loopNestingParamsAndArrs outer of
        [(p, mapped)]
          | expected == mapped ->
              case inner of
                [] -> paramName p == stm_arr
                next : rest -> perfectlyMapped (paramName p) (next, rest)
        _ -> False
-- | Run 'distributeIfPossible' on the accumulator.  If it produces a
-- new accumulator, that one is returned; otherwise the accumulator
-- that was passed in is returned unchanged.
distribute ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
distribute acc = do
  outcome <- distributeIfPossible acc
  pure $ case outcome of
    Just acc' -> acc'
    Nothing -> acc
-- | Produce a level-making function that works in builders over
-- 'DistNestT'.
--
-- The function wraps the 'distSegLevel' of the environment: the
-- underlying function is run in a builder of its own (via
-- 'runBuilderT''), lifted into 'DistNestT' with 'liftInner', and the
-- statements it built are then added to the calling builder before the
-- level is returned.
mkSegLevel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistNestT rep m (MkSegLevel rep (DistNestT rep m))
mkSegLevel = do
  inner_mk <- asks distSegLevel
  pure $ \ws desc recommendation -> do
    (lvl, lvl_stms) <-
      lift . liftInner . runBuilderT' $ inner_mk ws desc recommendation
    lvl <$ addStms lvl_stms
-- | Call 'tryDistribute' on the targets and statements of the
-- accumulator, using the current nesting from the environment.
--
-- Returns 'Nothing' if 'tryDistribute' does.  Otherwise the statements
-- it produced are posted, and the result is an accumulator holding the
-- new targets and no statements.
distributeIfPossible ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  DistNestT rep m (Maybe (DistAcc rep))
distributeIfPossible acc = do
  nest <- asks distNest
  mk_lvl <- mkSegLevel
  outcome <- tryDistribute mk_lvl nest (distTargets acc) (distStms acc)
  -- Traversing the 'Maybe' leaves 'Nothing' untouched and wraps the
  -- new accumulator in 'Just'.
  traverse
    ( \(targets, kernel) -> do
        postStm kernel
        pure DistAcc {distTargets = targets, distStms = mempty}
    )
    outcome
-- | Try to distribute a single statement on its own.
--
-- First 'tryDistribute' is called on the targets and statements of the
-- accumulator; then 'tryDistributeStm' is called on the given
-- statement with the targets that came out of the first step.  If
-- either step gives 'Nothing', so does this function.  Otherwise the
-- result consists of:
--
-- * the statements from the first step, wrapped as 'PostStms' (they
--   are returned, not posted);
--
-- * the result and the kernel nest from the second step;
--
-- * an accumulator holding the targets from the second step and no
--   statements.
distributeSingleStm ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  Stm SOACS ->
  DistNestT
    rep
    m
    ( Maybe
        ( PostStms rep,
          Result,
          KernelNest,
          DistAcc rep
        )
    )
distributeSingleStm acc stm = do
  nest <- asks distNest
  mk_lvl <- mkSegLevel
  first_step <- tryDistribute mk_lvl nest (distTargets acc) (distStms acc)
  case first_step of
    Nothing -> pure Nothing
    Just (targets, distributed_stms) ->
      fmap (package distributed_stms) <$> tryDistributeStm nest targets stm
  where
    -- Assemble the final tuple from the outcome of both steps.
    package distributed_stms (res, targets', new_kernel_nest) =
      ( PostStms distributed_stms,
        res,
        new_kernel_nest,
        DistAcc {distTargets = targets', distStms = mempty}
      )
-- | Build the statements for a scatter nested inside a kernel nest.
--
-- Arguments, in order: the enclosing kernel nest; a permutation that
-- is applied to the pattern elements of the outermost nesting to form
-- the pattern of the final statement; the pattern of the scatter; its
-- certificates; its width; its lambda; the arrays it maps over (paired
-- with the lambda parameters); and the destinations, as triples of
-- shape, count and array name.
--
-- The returned statements are the ones produced by 'mapKernel'
-- followed by the binding of the segmented operation, all passed
-- through 'renameStm'.
--
-- Calls 'error' if a destination array is not among the kernel inputs
-- computed by 'flatKernel'.
segmentedScatterKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Pat Type ->
  Certs ->
  SubExp ->
  Lambda rep ->
  [VName] ->
  [(Shape, Int, VName)] ->
  DistNestT rep m (Stms rep)
segmentedScatterKernel nest perm scatter_pat cs scatter_w lam ivs dests = do
  -- The scatter itself is pushed onto the nest as an innermost map
  -- nesting, so that 'flatKernel' also accounts for it.
  let lam_body = lambdaBody lam
      scatter_nesting =
        MapNesting
          scatter_pat
          (StmAux cs mempty ())
          scatter_w
          (zip (lambdaParams lam) ivs)
      full_nest =
        pushInnerKernelNesting
          (scatter_pat, bodyResult lam_body)
          scatter_nesting
          nest
  (ispace, kernel_inps) <- flatKernel full_nest

  let (dest_shapes, dest_counts, dest_arrs) = unzip3 dests
      -- Number of leading lambda results that are indices: for each
      -- destination, its count times the rank of its shape.
      n_indices = sum (zipWith (*) dest_counts (map length dest_shapes))

  -- Each destination array has to be one of the kernel inputs.
  dest_inps <- traverse (lookupInput kernel_inps) dest_arrs

  mk_lvl <- mkSegLevel

  let -- One return type per destination: the first type of each chunk
      -- of the value types (those after the index types).
      rts =
        concatMap (take 1)
          . chunks dest_counts
          . drop n_indices
          $ lambdaReturnType lam
      (is, vs) = splitAt n_indices (bodyResult lam_body)

  (is', k_body_stms) <- runBuilder (is <$ addStms (bodyStms lam_body))

  let k_results =
        map (inPlaceReturn ispace) $
          groupScatterResults
            (zip3 dest_shapes dest_counts dest_inps)
            (is' ++ vs)
      k_body = KernelBody () k_body_stms k_results
      -- Only kernel inputs that the body actually mentions are kept.
      used_names = freeIn k_body
      used_inps =
        filter (\inp -> kernelInputName inp `nameIn` used_names) kernel_inps

  (k, k_stms) <- mapKernel mk_lvl ispace used_inps rts k_body

  final_stms <- runBuilder_ $ do
    addStms k_stms
    let result_pat =
          Pat (rearrangeShape perm (patElems (loopNestingPat (fst nest))))
    letBind result_pat (Op (segOp k))
  traverse renameStm final_stms
  where
    -- Find the kernel input with the given name, or fail loudly.
    lookupInput inputs name =
      maybe
        (error "Ill-typed nested scatter encountered.")
        pure
        (find ((name ==) . kernelInputName) inputs)

    -- Turn one group of scatter results into a 'WriteReturns' on the
    -- array behind the given kernel input.  The last dimension of the
    -- index space is the scatter itself, so it is left out of both the
    -- shape and the leading indices of each write.
    inPlaceReturn ispace (aw, inp, is_vs) =
      WriteReturns all_certs full_shape (kernelInputArray inp) (map write is_vs)
      where
        (gtids, ws) = unzip ispace
        all_certs =
          foldMap (foldMap resCerts . fst) is_vs
            <> foldMap (resCerts . snd) is_vs
        full_shape = Shape (init ws ++ shapeDims aw)
        outer_is = map Var (init gtids)
        write (is, v) =
          ( Slice (map DimFix (outer_is ++ map resSubExp is)),
            resSubExp v
          )
-- | Turn an in-place update that sits inside a map nest into a single
-- kernel.
--
-- The iteration space is the flattened nest extended with one fresh
-- thread index per dimension of @slice@.  Each thread reads one
-- element of @v@ (under the certificates @cs@) and writes it, through
-- a 'WriteReturns' result, into the array of the kernel input named
-- @arr@.  The kernel is bound to the pattern of the outermost nesting
-- after rearranging its elements by @perm@, and all produced
-- statements are renamed before being returned.
segmentedUpdateKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Certs ->
  VName ->
  Slice SubExp ->
  VName ->
  DistNestT rep m (Stms rep)
segmentedUpdateKernel nest perm cs arr slice v = do
  (nest_ispace, nest_inps) <- flatKernel nest
  let update_dims = sliceDims slice
  update_gtids <- replicateM (length update_dims) (newVName "gtid_slice")
  let full_ispace = nest_ispace ++ zip update_gtids update_dims

  ((elem_t, write_res), body_stms) <- runBuilder $ do
    -- The element of 'v' that this thread is responsible for.
    elem_se <-
      certifying cs . letSubExp "v" . BasicOp . Index v . Slice $
        map (DimFix . Var) update_gtids
    -- The thread's position inside 'slice', as indexes of the sliced array.
    inner_is <-
      mapM (toSubExp "index") $
        fixSlice (pe64 <$> slice) [pe64 (Var gtid) | gtid <- update_gtids]
    let target_is = [Var gtid | (gtid, _) <- nest_ispace] ++ inner_is
        -- The update target must be one of the inputs of the nest.
        target_arr =
          case find ((== arr) . kernelInputName) nest_inps of
            Just inp -> kernelInputArray inp
            Nothing -> error "incorrectly typed Update"
    target_t <- lookupType target_arr
    written_t <- subExpType elem_se
    pure
      ( written_t,
        WriteReturns
          mempty
          (arrayShape target_t)
          target_arr
          [(Slice (map DimFix target_is), elem_se)]
      )

  -- Keep only the inputs that the kernel body or its result mention.
  let live_names = freeIn body_stms <> freeIn write_res
      live_inps =
        [inp | inp <- nest_inps, kernelInputName inp `nameIn` live_names]

  mk_lvl <- mkSegLevel
  (kernel, pre_stms) <-
    mapKernel
      mk_lvl
      full_ispace
      live_inps
      [elem_t]
      (KernelBody () body_stms [write_res])

  let result_pat =
        Pat (rearrangeShape perm (patElems (loopNestingPat (fst nest))))
  kernel_stms <- runBuilder_ $ do
    addStms pre_stms
    letBind result_pat (Op (segOp kernel))
  traverse renameStm kernel_stms
-- | Turn a slicing of @arr@ that sits inside a map nest into a single
-- kernel.
--
-- One fresh thread index is created per dimension of @slice@ and
-- appended to the flattened iteration space of the nest.  Each thread
-- composes @slice@ with the point given by its own slice indexes,
-- reads that element of @arr@ (under the certificates @cs@) and
-- returns it.  The kernel is bound to the pattern of the outermost
-- nesting, and all produced statements are renamed before being
-- returned.
segmentedGatherKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  Certs ->
  VName ->
  Slice SubExp ->
  DistNestT rep m (Stms rep)
segmentedGatherKernel nest cs arr slice = do
  let gather_dims = sliceDims slice
  gather_gtids <- replicateM (length gather_dims) (newVName "gtid_slice")

  (nest_ispace, nest_inps) <- flatKernel nest
  let full_ispace = nest_ispace ++ zip gather_gtids gather_dims
      -- The single point of the slice that a thread handles.
      point_slice = Slice [DimFix (Var gtid) | gtid <- gather_gtids]

  ((elem_t, elem_res), body_stms) <- runBuilder $ do
    elem_slice <-
      subExpSlice $
        sliceSlice (primExpSlice slice) (primExpSlice point_slice)
    elem_se <- certifying cs (letSubExp "v" (BasicOp (Index arr elem_slice)))
    read_t <- subExpType elem_se
    pure (read_t, Returns ResultMaySimplify mempty elem_se)

  mk_lvl <- mkSegLevel
  (kernel, pre_stms) <-
    mapKernel
      mk_lvl
      full_ispace
      nest_inps
      [elem_t]
      (KernelBody () body_stms [elem_res])

  let result_pat = Pat (patElems (loopNestingPat (fst nest)))
  kernel_stms <- runBuilder_ $ do
    addStms pre_stms
    letBind result_pat (Op (segOp kernel))
  traverse renameStm kernel_stms
-- | Build the statements for a histogram that sits inside a map nest.
--
-- The nest is flattened, and the destinations of every histogram
-- operator are replaced by the arrays of the kernel inputs that carry
-- the same names; a destination without a matching kernel input is
-- reported with 'error'.  The level is requested for @hist_w@ followed
-- by the nest dimensions, and the actual construction is delegated to
-- 'histKernel', using the pattern of the outermost nesting rearranged
-- by @perm@.
segmentedHistKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Certs ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda rep ->
  [VName] ->
  DistNestT rep m (Stms rep)
segmentedHistKernel nest perm cs hist_w ops lam arrs = do
  (ispace, inputs) <- flatKernel nest
  let orig_pat =
        Pat (rearrangeShape perm (patElems (loopNestingPat (fst nest))))

  -- Redirect every destination to the array behind its kernel input.
  ops' <- forM ops $ \(SOACS.HistOp num_bins rf dests nes op) -> do
    dests' <- mapM (fmap kernelInputArray . inputNamed inputs) dests
    pure $ SOACS.HistOp num_bins rf dests' nes op

  mk_lvl <- asks distSegLevel
  onLambda <- asks distOnSOACSLambda
  -- Only the converted lambda is kept; statements built on the side
  -- by the conversion are discarded.
  let onLambda' = fmap fst . runBuilder . onLambda
      level_dims = hist_w : map snd ispace

  liftInner . runBuilderT'_ $ do
    lvl <- mk_lvl level_dims "seghist" (NoRecommendation SegNoVirt)
    hist_stms <-
      histKernel onLambda' lvl orig_pat ispace inputs cs hist_w ops' lam arrs
    addStms hist_stms
  where
    -- Look up the kernel input with the given name.
    inputNamed kernel_inps name =
      case find ((== name) . kernelInputName) kernel_inps of
        Just inp -> pure inp
        Nothing -> error "Ill-typed nested Hist encountered."
-- | Produce the statements of a segmented histogram.
--
-- Every SOACS-level operator is first passed through
-- 'determineReduceOp' and its lambda converted with @onLambda@,
-- giving the operators handed to 'segHist'.  Kernel inputs whose
-- array is one of the histogram destinations are left out of the
-- inputs given to 'segHist'.  The resulting statements are renamed
-- and added under the certificates @cs@.
histKernel ::
  (MonadBuilder m, DistRep (Rep m)) =>
  (Lambda SOACS -> m (Lambda (Rep m))) ->
  SegOpLevel (Rep m) ->
  Pat Type ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  Certs ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda (Rep m) ->
  [VName] ->
  m (Stms (Rep m))
histKernel onLambda lvl orig_pat ispace inputs cs hist_w ops lam arrs =
  runBuilderT'_ $ do
    seg_ops <- forM ops $ \(SOACS.HistOp dest_shape rf dests nes op) -> do
      (reduce_lam, reduce_nes, vec_shape) <- determineReduceOp op nes
      -- The conversion runs in the underlying monad.
      seg_lam <- lift (onLambda reduce_lam)
      pure (HistOp dest_shape rf dests reduce_nes vec_shape seg_lam)

    let all_dests = concatMap histDest seg_ops
        non_dest_inputs =
          [inp | inp <- inputs, kernelInputArray inp `notElem` all_dests]

    certifying cs $ do
      hist_stms <-
        segHist lvl orig_pat hist_w ispace non_dest_inputs seg_ops lam arrs
      renamed_stms <- traverse renameStm hist_stms
      addStms renamed_stms
-- | Try to peel the vectorised part off a reduction operator.
--
-- When every neutral element is a variable, the operator is examined
-- with 'isVectorMap'; each neutral element is then replaced by the
-- result of indexing it with one zero per dimension of the shape that
-- was found (the remaining dimensions are sliced in full), and the
-- inner lambda, the new neutral elements and the shape are returned.
-- If some neutral element is not a variable, the operator and the
-- neutral elements are returned unchanged together with an empty
-- shape.
determineReduceOp ::
  MonadBuilder m =>
  Lambda SOACS ->
  [SubExp] ->
  m (Lambda SOACS, [SubExp], Shape)
determineReduceOp lam nes
  | Just ne_vs <- mapM subExpVar nes = do
      scalar_nes <- mapM firstElement ne_vs
      pure (inner_lam, scalar_nes, vec_shape)
  | otherwise = pure (lam, nes, mempty)
  where
    (vec_shape, inner_lam) = isVectorMap lam

    -- One fixed zero index for every vectorised dimension.
    zero_ixs = replicate (shapeRank vec_shape) (DimFix (intConst Int64 0))

    firstElement ne_v = do
      ne_v_t <- lookupType ne_v
      letSubExp "hist_ne" . BasicOp . Index ne_v $ fullSlice ne_v_t zero_ixs
-- | Recognise a lambda that is just (possibly nested) maps over its
-- own parameters.
--
-- A lambda matches when its body is a single 'Screma' statement that
-- is a map, whose input arrays are exactly the lambda parameters and
-- whose pattern names are exactly the body result.  On a match the
-- width of the map is put in front of the shape found by examining
-- the mapped lambda in the same way, and the innermost lambda is
-- returned.  Otherwise the result is an empty shape and the lambda
-- itself.
isVectorMap :: Lambda SOACS -> (Shape, Lambda SOACS)
isVectorMap lam =
  case stmsToList (bodyStms body) of
    [Let (Pat pes) _ (Op (Screma w arrs form))]
      | map resSubExp (bodyResult body) == map (Var . patElemName) pes,
        Just map_lam <- isMapSOAC form,
        arrs == map paramName (lambdaParams lam) ->
          first (Shape [w] <>) (isVectorMap map_lam)
    _ -> (mempty, lam)
  where
    body = lambdaBody lam
-- | Try to turn a scan that sits inside a map nest into a segmented
-- scan.
--
-- The decision whether the nest is suitable is left to
-- 'isSegmentedOp', which receives the free variables of both lambdas
-- and the neutral elements but no arrays, so the result is 'Nothing'
-- whenever 'isSegmentedOp' gives up.  Otherwise the scan operator is
-- passed through 'determineReduceOp', converted from SOACS, and
-- handed to 'segScan' as a non-commutative operator together with the
-- original @arrs@; the produced statements are renamed.
segmentedScanomapKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Certs ->
  SubExp ->
  Lambda SOACS ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  DistNestT rep m (Maybe (Stms rep))
segmentedScanomapKernel nest perm cs segment_size lam map_lam nes arrs = do
  mk_lvl <- asks distSegLevel
  onLambda <- asks distOnSOACSLambda
  -- Only the converted lambda is kept; statements built on the side
  -- by the conversion are discarded.
  let onLambda' = fmap fst . runBuilder . onLambda
  isSegmentedOp nest perm (freeIn lam) (freeIn map_lam) nes [] $
    \pat ispace inps prepared_nes _ -> do
      (scan_lam, scan_nes, vec_shape) <- determineReduceOp lam prepared_nes
      seg_lam <- onLambda' scan_lam
      lvl <-
        mk_lvl
          (segment_size : map snd ispace)
          "segscan"
          (NoRecommendation SegNoVirt)
      scan_stms <-
        segScan
          lvl
          pat
          cs
          segment_size
          [SegBinOp Noncommutative seg_lam scan_nes vec_shape]
          map_lam
          arrs
          ispace
          inps
      addStms =<< traverse renameStm scan_stms
-- | Try to turn a reduction that sits inside a map nest into a
-- segmented reduction.
--
-- The decision whether the nest is suitable is left to
-- 'isSegmentedOp', which receives the free variables of both lambdas
-- and the neutral elements but no arrays, so the result is 'Nothing'
-- whenever 'isSegmentedOp' gives up.  Otherwise a single operator
-- with commutativity @comm@, the prepared neutral elements and an
-- empty shape is handed to 'segRed' together with the original
-- @arrs@; the produced statements are renamed.
regularSegmentedRedomapKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Certs ->
  SubExp ->
  Commutativity ->
  Lambda rep ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  DistNestT rep m (Maybe (Stms rep))
regularSegmentedRedomapKernel nest perm cs segment_size comm lam map_lam nes arrs = do
  mk_lvl <- asks distSegLevel
  isSegmentedOp nest perm (freeIn lam) (freeIn map_lam) nes [] $
    \pat ispace inps prepared_nes _ -> do
      lvl <-
        mk_lvl
          (segment_size : map snd ispace)
          "segred"
          (NoRecommendation SegNoVirt)
      red_stms <-
        segRed
          lvl
          pat
          cs
          segment_size
          [SegBinOp comm lam prepared_nes mempty]
          map_lam
          arrs
          ispace
          inps
      addStms =<< traverse renameStm red_stms
-- | Try to express a segmented operation sitting inside a kernel nest
-- as statements built by the continuation @m@.
--
-- The result is 'Nothing' (the 'fail' messages below are discarded by
-- 'MaybeT') when any of the following holds:
--
-- * a name in @free_in_op@ is bound by the nest;
--
-- * a neutral element in @nes@ is a variable bound by the nest;
--
-- * an array in @arrs@ is neither a kernel input indexed by exactly
--   the indices of the flattened iteration space, nor a name that is
--   not bound by the nest.
--
-- Otherwise @m@ is run inside a builder and the statements it produces
-- are returned.  It receives the pattern of the outermost nesting
-- rearranged by @perm@, the flattened iteration space, the kernel
-- inputs, the neutral elements, and the prepared arrays.
isSegmentedOp ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Names ->
  Names ->
  [SubExp] ->
  [VName] ->
  ( Pat Type ->
    [(VName, SubExp)] ->
    [KernelInput] ->
    [SubExp] ->
    [VName] ->
    BuilderT rep m ()
  ) ->
  DistNestT rep m (Maybe (Stms rep))
isSegmentedOp nest perm free_in_op _free_in_fold_op nes arrs m = runMaybeT $ do
  let bound_by_nest = boundInKernelNest nest

  (ispace, kernel_inps) <- flatKernel nest

  -- Nothing free in the operation may be bound by the nest.
  when (free_in_op `namesIntersect` bound_by_nest) $
    fail "Non-fold lambda uses nest-bound parameters."

  let indices = map fst ispace

      -- Neutral elements must not be variables bound by the nest.
      prepareNe (Var v)
        | v `nameIn` bound_by_nest =
            fail "Neutral element bound in nest"
      prepareNe ne = pure ne

      -- Produces a builder action yielding the array to use in place
      -- of @arr@; the action is only run later, inside the builder.
      prepareArr arr =
        case find ((== arr) . kernelInputName) kernel_inps of
          Just inp
            | kernelInputIndices inp == map Var indices ->
                -- Indexed by exactly the iteration space: use the
                -- underlying array directly.
                pure $ pure $ kernelInputArray inp
          Nothing
            | arr `notNameIn` bound_by_nest ->
                -- Not a kernel input and not bound by the nest:
                -- replicate it across the iteration space.
                pure $
                  letExp
                    (baseString arr ++ "_repd")
                    (BasicOp $ Replicate (Shape $ map snd ispace) $ Var arr)
          _ ->
            fail "Input not free, perfectly mapped, or outermost."

  nes' <- mapM prepareNe nes

  mk_arrs <- mapM prepareArr arrs

  lift $
    liftInner $
      runBuilderT'_ $ do
        nested_arrs <- sequence mk_arrs

        let pat =
              Pat . rearrangeShape perm $
                patElems $
                  loopNestingPat $
                    fst nest

        m pat ispace kernel_inps nes' nested_arrs
-- | Relate the elements of a pattern to a result.
--
-- The pattern elements whose names do not occur free in the result are
-- collected as "unused".  The result is then extended with the names
-- of those unused elements, and the pattern's names are checked
-- against this extended result with 'isPermutationOf'.  On success the
-- permutation is returned together with the unused pattern elements;
-- otherwise 'Nothing'.
permutationAndMissing :: Pat Type -> Result -> Maybe ([Int], [PatElem Type])
permutationAndMissing (Pat pes) res = do
  let (_used, unused) =
        partition ((`nameIn` freeIn res) . patElemName) pes
      res' = map resSubExp res
      -- The result, padded with the names the result does not mention.
      res_expanded = res' ++ map (Var . patElemName) unused
  perm <- map (Var . patElemName) pes `isPermutationOf` res_expanded
  pure (perm, unused)
-- | Extend the pattern of every nesting in a kernel nest with fresh
-- copies of the given pattern elements.
--
-- Each copy gets a fresh name derived from the base string of the
-- original name.  Its type is the original element's type wrapped (via
-- 'arrayOfShape') in the widths of the nesting being extended followed
-- by the widths of all nestings inside it.
expandKernelNest ::
  MonadFreshNames m => [PatElem Type] -> KernelNest -> m KernelNest
expandKernelNest pes (outer_nest, inner_nests) = do
  let outer_size =
        loopNestingWidth outer_nest
          : map loopNestingWidth inner_nests
      -- The i'th suffix holds the widths of the i'th inner nesting and
      -- of every nesting inside it; 'zipWithM' ignores the trailing
      -- empty suffix.
      inner_sizes = tails $ map loopNestingWidth inner_nests
  outer_nest' <- expandWith outer_nest outer_size
  inner_nests' <- zipWithM expandWith inner_nests inner_sizes
  pure (outer_nest', inner_nests')
  where
    -- Append expanded copies of @pes@ to the pattern of one nesting.
    expandWith nest dims = do
      pes' <- mapM (expandPatElemWith dims) pes
      pure
        nest
          { loopNestingPat =
              Pat $ patElems (loopNestingPat nest) <> pes'
          }

    -- Freshly named copy of a pattern element with @dims@ added as
    -- outer dimensions of its type.
    expandPatElemWith dims pe = do
      name <- newVName $ baseString $ patElemName pe
      pure
        pe
          { patElemName = name,
            patElemDec = patElemType pe `arrayOfShape` Shape dims
          }
-- | Pick the accumulator to continue with, depending on whether kernel
-- statements were produced for a statement.
--
-- With 'Nothing', the statement (with the certificates @cs@ attached)
-- is added to the first accumulator via 'addStmToAcc'; the post
-- statements and the second accumulator are ignored.
--
-- With @'Just' stms@, the given post statements are emitted with
-- 'addPostStms', then @stms@ (each certified with @cs@) are emitted
-- with 'postStm', and the second accumulator is returned; the original
-- statement and the first accumulator are ignored.
kernelOrNot ::
  (MonadFreshNames m, DistRep rep) =>
  Certs ->
  Stm SOACS ->
  DistAcc rep ->
  PostStms rep ->
  DistAcc rep ->
  Maybe (Stms rep) ->
  DistNestT rep m (DistAcc rep)
kernelOrNot cs stm acc _ _ Nothing =
  addStmToAcc (certify cs stm) acc
kernelOrNot cs _ _ kernels acc' (Just stms) = do
  addPostStms kernels
  postStm $ fmap (certify cs) stms
  pure acc'
-- | Distribute a map loop.
--
-- The statements of the map's lambda body are processed with
-- 'distributeMapBodyStms' and 'distribute' inside 'mapNesting', using
-- an accumulator that has no statements yet and whose targets are the
-- incoming targets with the map's pattern and lambda result pushed as
-- the innermost target.  The accumulator that comes out of the nesting
-- is then passed to 'distribute' once more.
distributeMap ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  MapLoop ->
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
distributeMap (MapLoop pat aux w lam arrs) acc =
  distribute
    =<< mapNesting
      pat
      aux
      w
      lam
      arrs
      (distribute =<< distributeMapBodyStms acc' lam_stms)
  where
    -- Accumulator used while inside the map's body.
    acc' =
      DistAcc
        { distTargets =
            pushInnerTarget
              (pat, bodyResult $ lambdaBody lam)
              $ distTargets acc,
          distStms = mempty
        }

    lam_stms = bodyStms $ lambdaBody lam