{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
module Futhark.Pass.ExtractKernels.DistributeNests
  ( KernelsStms
  , MapLoop(..)
  , mapLoopStm

  , bodyContainsParallelism
  , lambdaContainsParallelism
  , determineReduceOp
  , incrementalFlattening
  , histKernel

  , DistEnv (..)
  , DistAcc (..)
  , runDistNestT
  , DistNestT

  , distributeMap

  , distribute
  , distributeSingleStm
  , distributeMapBodyStms
  , postKernelsStms
  , addStmsToKernel
  , addStmToKernel
  , permutationAndMissing
  , addKernels
  , addKernel
  , inNesting
  )
where

import Control.Arrow (first)
import Control.Monad.Identity
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.Writer.Strict
import Control.Monad.Trans.Maybe
import Data.Maybe
import Data.List (find, partition, tails)

import Futhark.Representation.SOACS
import qualified Futhark.Representation.SOACS.SOAC as SOAC
import Futhark.Representation.SOACS.Simplify (simpleSOACS)
import qualified Futhark.Representation.Kernels as Out
import Futhark.Representation.Kernels.Kernel
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.CopyPropagate
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.BlockedKernel hiding (segThread)
import Futhark.Pass.ExtractKernels.Interchange
import Futhark.Util
import Futhark.Util.Log

-- | A @map@ statement considered for distribution.  Its components are
-- exactly what 'mapLoopStm' needs to rebuild the statement.
data MapLoop =
  MapLoop
    Pattern       -- pattern the map's result is bound to
    Certificates  -- certificates attached to the statement
    SubExp        -- width of the map
    Lambda        -- the mapped function
    [VName]       -- input arrays

-- | Rebuild the SOACS statement described by a 'MapLoop': a map-form
-- 'Screma' over the input arrays, bound to the loop's pattern and
-- carrying the loop's certificates.
mapLoopStm :: MapLoop -> Stm
mapLoopStm (MapLoop pat cs w lam arrs) =
  Let pat aux (Op soac)
  where aux = StmAux cs ()
        soac = Screma w (mapSOAC lam) arrs

-- | Shorthand for a sequence of statements in the Kernels representation.
type KernelsStms = Out.Stms Out.Kernels

-- | The read-only environment of a 'DistNestT' computation.
data DistEnv m =
  DistEnv
  { distNest :: Nestings
    -- ^ The current nesting.  Replaced by 'inNesting', extended by
    -- 'mapNesting' and 'withStm', and consulted by 'runDistNestT'.
  , distScope :: Scope Out.Kernels
    -- ^ Names in scope, in the Kernels representation.
  , distOnTopLevelStms :: Stms SOACS -> DistNestT m (Stms Out.Kernels)
    -- ^ Caller-supplied handler for SOACS statements.  NOTE(review): its
    -- call sites are outside this excerpt; presumably used for statements
    -- that end up at top level -- confirm.
  , distOnInnerMap :: MapLoop -> DistAcc -> DistNestT m DistAcc
    -- ^ Caller-supplied handler for an inner 'MapLoop'.  NOTE(review):
    -- call sites not visible here.
  , distSegLevel :: MkSegLevel m
    -- ^ Caller-supplied 'MkSegLevel'; its meaning is defined elsewhere.
  }

-- | Accumulated distribution state: pending result targets and the
-- Kernels statements collected so far.
data DistAcc =
  DistAcc
  { distTargets :: Targets
    -- ^ Result targets.  Whatever remains at the outer target when
    -- 'runDistNestT' finishes is turned into identity statements.
  , distStms :: KernelsStms
    -- ^ Collected statements.  New statements are prepended
    -- (see 'addStmsToKernel' and 'addStmToKernel').
  }

-- | Output accumulated by the writer layer of 'DistNestT'.
data DistRes =
  DistRes
  { accPostKernels :: PostKernels
    -- ^ Kernels emitted with 'addKernel' / 'addKernels'.
  , accLog :: Log
    -- ^ Log messages; 'runDistNestT' forwards them to the underlying
    -- monad's logger.
  }

-- Each component is combined with its own '<>'.  Matching on both
-- constructors keeps the operator strict in both arguments.
instance Semigroup DistRes where
  (<>) (DistRes kernels_a log_a) (DistRes kernels_b log_b) =
    DistRes { accPostKernels = kernels_a <> kernels_b
            , accLog = log_a <> log_b
            }

instance Monoid DistRes where
  mempty = DistRes { accPostKernels = mempty, accLog = mempty }

-- | One group of Kernels statements, emitted as a unit by 'addKernel'.
newtype PostKernel = PostKernel { unPostKernel :: KernelsStms }

-- | All kernel groups emitted during distribution.
newtype PostKernels = PostKernels [PostKernel]

-- The order is deliberately reversed: in @a <> b@ the groups of @b@ come
-- before those of @a@.  As a result, groups written later through the
-- writer appear earlier in the list.
instance Semigroup PostKernels where
  PostKernels older <> PostKernels newer = PostKernels (newer ++ older)

instance Monoid PostKernels where
  mempty = PostKernels []

-- | Concatenate all emitted kernel groups into one statement sequence,
-- keeping the list order of the groups.
postKernelsStms :: PostKernels -> KernelsStms
postKernelsStms (PostKernels kernels) = foldMap unPostKernel kernels

-- | The scope bound by the pattern of the accumulator's outer target.
typeEnvFromDistAcc :: DistAcc -> Scope Out.Kernels
typeEnvFromDistAcc acc = scopeOfPattern outer_pat
  where (outer_pat, _) = outerTarget (distTargets acc)

-- | Prepend Kernels statements to those already in the accumulator.
addStmsToKernel :: KernelsStms -> DistAcc -> DistAcc
addStmsToKernel new_stms acc = acc { distStms = combined }
  where combined = new_stms <> distStms acc

-- | Convert one SOACS statement to the Kernels representation with
-- 'soacsStmToKernels' and prepend it to the accumulator's statements.
-- The updated accumulator is simply returned in @m@; no effects occur.
addStmToKernel :: Monad m => Stm -> DistAcc -> m DistAcc
addStmToKernel stm acc =
  return $ acc { distStms = converted <> distStms acc }
  where converted = oneStm (soacsStmToKernels stm)

-- | The distribution monad transformer: a reader of 'DistEnv' stacked on
-- a strict writer of 'DistRes', over the base monad @m@.  The instances
-- below are obtained through newtype deriving.
newtype DistNestT m a =
  DistNestT (ReaderT (DistEnv m) (WriterT DistRes m) a)
  deriving ( Functor
           , Applicative
           , Monad
           , MonadReader (DistEnv m)
           , MonadWriter DistRes
           )

-- Lifting passes through both layers: first into the writer, then into
-- the reader.
instance MonadTrans DistNestT where
  lift action = DistNestT (lift (lift action))

-- Name-source operations are delegated to the writer layer beneath the
-- reader.
instance MonadFreshNames m => MonadFreshNames (DistNestT m) where
  getNameSource = DistNestT (lift getNameSource)
  putNameSource src = DistNestT (lift (putNameSource src))

-- The scope is read directly from the environment.
instance Monad m => HasScope Out.Kernels (DistNestT m) where
  askScope = fmap distScope ask

-- The extra bindings are combined on the left of the existing scope for
-- the duration of the action.
instance Monad m => LocalScope Out.Kernels (DistNestT m) where
  localScope extra_scope = local widen
    where widen env = env { distScope = extra_scope <> distScope env }

-- Log messages go to the 'accLog' component of the writer output, with
-- no kernels attached.
instance Monad m => MonadLogger (DistNestT m) where
  addLog msgs = tell $ DistRes { accPostKernels = mempty, accLog = msgs }

-- | Run a distribution computation in the environment @env@.
--
-- The computation's log is forwarded to the underlying monad.  The
-- result is every kernel emitted through the writer (in
-- 'postKernelsStms' order), followed by statements for the targets that
-- remain at the outermost level; those correspond to identity-mapped
-- arrays.  Each remaining result is bound as follows:
--
-- * a result that is a parameter of the outermost map nesting is bound
--   by copying the corresponding input array;
-- * any other result is replicated across the width of that nesting.
runDistNestT :: MonadLogger m =>
                DistEnv m -> DistNestT m DistAcc -> m (Out.Stms Out.Kernels)
runDistNestT env (DistNestT action) = do
  (final_acc, dist_res) <- runWriterT (runReaderT action env)
  addLog (accLog dist_res)
  let leftovers = identityStms $ outerTarget $ distTargets final_acc
  return $ postKernelsStms (accPostKernels dist_res) <> leftovers
  where
    -- The outermost loop is the head of the outer list when that list is
    -- non-empty, and otherwise the only (inner) nesting.
    outermost =
      let (only_nest, outer_nests) = distNest env
      in nestingLoop (fromMaybe only_nest (listToMaybe outer_nests))

    param_arr_pairs = first paramName <$> loopNestingParamsAndArrs outermost

    identityStms (rem_pat, ses) =
      stmsFromList
      [ identityStm pe se | (pe, se) <- zip (patternValueElements rem_pat) ses ]

    identityStm pe se =
      Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ identityOp se

    identityOp (Var v)
      | Just arr <- lookup v param_arr_pairs = Copy arr
    identityOp se =
      Replicate (Shape [loopNestingWidth outermost]) se

-- | Emit a batch of kernels through the writer, with no log output.
addKernels :: Monad m => PostKernels -> DistNestT m ()
addKernels ks = tell $ DistRes { accPostKernels = ks, accLog = mempty }

-- | Emit a single group of Kernels statements as one kernel.
addKernel :: Monad m => KernelsStms -> DistNestT m ()
addKernel = addKernels . PostKernels . pure . PostKernel

-- | Run an action in an environment that knows about the bindings of
-- @stm@.  Its scope, converted with 'scopeForKernels', is added to
-- 'distScope', and the names bound by its pattern are registered in the
-- nesting via 'letBindInInnerNesting'.
withStm :: Monad m => Stm -> DistNestT m a -> DistNestT m a
withStm stm = local extend
  where
    bound_names = namesFromList $ patternNames $ stmPattern stm
    extend env =
      env { distNest = letBindInInnerNesting bound_names (distNest env)
          , distScope = scopeForKernels (scopeOf stm) <> distScope env
          }

-- | Run an action inside a new innermost map nesting.  The nesting pairs
-- the lambda's parameters with the input arrays and is pushed with
-- 'pushInnerNesting'.  The lambda's scope, converted with
-- 'scopeForKernels', is added to 'distScope'.
mapNesting :: Monad m =>
              Pattern -> Certificates -> SubExp -> Lambda -> [VName]
           -> DistNestT m a
           -> DistNestT m a
mapNesting pat cs w lam arrs = local enterMap
  where
    map_nest =
      Nesting mempty (MapNesting pat cs w (zip (lambdaParams lam) arrs))
    enterMap env =
      env { distScope = scopeForKernels (scopeOf lam) <> distScope env
          , distNest = pushInnerNesting map_nest (distNest env)
          }

-- | Run an action with the current nesting replaced by the given kernel
-- nest.  The last loop of the nest (or @outer@ itself, when the list is
-- empty) becomes the inner nesting.  The remaining loops, outermost
-- first, form the outer list.  The scope of every loop in the nest is
-- added to 'distScope'.
inNesting :: Monad m =>
             KernelNest -> DistNestT m a -> DistNestT m a
inNesting (outer, nests) = local enter
  where
    enter env =
      env { distNest = unrollNest outer nests
          , distScope = mconcat [ scopeOf loop | loop <- outer : nests ]
                        <> distScope env
          }

    -- Walk the loops from the outside in.  The last loop becomes the
    -- inner nesting; every earlier one is kept, in order, as an outer
    -- nesting.
    unrollNest current [] = (Nesting mempty current, [])
    unrollNest current (next : rest) =
      let (innermost, outers) = unrollNest next rest
      in (innermost, Nesting mempty current : outers)

-- | Does the body contain parallelism?  True when any statement directly
-- in the body is an 'Op', which in the SOACS representation is a SOAC.
-- Nested bodies are not inspected.
bodyContainsParallelism :: Body -> Bool
bodyContainsParallelism = any isSOACStm . bodyStms
  where isSOACStm stm = case stmExp stm of
                          Op{} -> True
                          _    -> False

-- | 'bodyContainsParallelism' applied to the body of a lambda.
lambdaContainsParallelism :: Lambda -> Bool
lambdaContainsParallelism lam = bodyContainsParallelism (lambdaBody lam)

-- | Whether to generate the experimental multi-versioned (incrementally
-- flattened) code.  It is switched on by the mere presence of the key
-- @FUTHARK_INCREMENTAL_FLATTENING@ in 'unixEnvironment'; the associated
-- value, even an empty one, is ignored.  Beware: the resulting code may
-- be slower in practice.
incrementalFlattening :: Bool
incrementalFlattening =
  any ((== "FUTHARK_INCREMENTAL_FLATTENING") . fst) unixEnvironment

leavingNesting :: Monad m => MapLoop -> DistAcc -> DistNestT m DistAcc
leavingNesting :: MapLoop -> DistAcc -> DistNestT m DistAcc
leavingNesting (MapLoop Pattern
_ Certificates
cs SubExp
w Lambda
lam [VName]
arrs) DistAcc
acc =
  case Targets -> Maybe (Target, Targets)
popInnerTarget (Targets -> Maybe (Target, Targets))
-> Targets -> Maybe (Target, Targets)
forall a b. (a -> b) -> a -> b
$ DistAcc -> Targets
distTargets DistAcc
acc of
   Maybe (Target, Targets)
Nothing ->
     String -> DistNestT m DistAcc
forall a. HasCallStack => String -> a
error String
"The kernel targets list is unexpectedly small"
   Just ((Pattern Kernels
pat,Result
res), Targets
newtargets) -> do
     let acc' :: DistAcc
acc' = DistAcc
acc { distTargets :: Targets
distTargets = Targets
newtargets }
     if Stms Kernels -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (Stms Kernels -> Bool) -> Stms Kernels -> Bool
forall a b. (a -> b) -> a -> b
$ DistAcc -> Stms Kernels
distStms DistAcc
acc'
       then DistAcc -> DistNestT m DistAcc
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc
acc'
       else do let body :: BodyT Kernels
body = BodyAttr Kernels -> Stms Kernels -> Result -> BodyT Kernels
forall lore. BodyAttr lore -> Stms lore -> Result -> BodyT lore
Body () (DistAcc -> Stms Kernels
distStms DistAcc
acc') Result
res
                   used_in_body :: Names
used_in_body = BodyT Kernels -> Names
forall a. FreeIn a => a -> Names
freeIn BodyT Kernels
body
                   ([Param Type]
used_params, [VName]
used_arrs) =
                     [(Param Type, VName)] -> ([Param Type], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Param Type, VName)] -> ([Param Type], [VName]))
-> [(Param Type, VName)] -> ([Param Type], [VName])
forall a b. (a -> b) -> a -> b
$
                     ((Param Type, VName) -> Bool)
-> [(Param Type, VName)] -> [(Param Type, VName)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName -> Names -> Bool
`nameIn` Names
used_in_body) (VName -> Bool)
-> ((Param Type, VName) -> VName) -> (Param Type, VName) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall attr. Param attr -> VName
paramName (Param Type -> VName)
-> ((Param Type, VName) -> Param Type)
-> (Param Type, VName)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param Type, VName) -> Param Type
forall a b. (a, b) -> a
fst) ([(Param Type, VName)] -> [(Param Type, VName)])
-> [(Param Type, VName)] -> [(Param Type, VName)]
forall a b. (a -> b) -> a -> b
$
                     [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda
lam) [VName]
arrs
                   lam' :: LambdaT Kernels
lam' = Lambda :: forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda { lambdaParams :: [LParam Kernels]
lambdaParams = [Param Type]
[LParam Kernels]
used_params
                                 , lambdaBody :: BodyT Kernels
lambdaBody = BodyT Kernels
body
                                 , lambdaReturnType :: [Type]
lambdaReturnType = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ PatternT Type -> [Type]
forall attr. Typed attr => PatternT attr -> [Type]
patternTypes PatternT Type
Pattern Kernels
pat
                                 }
               let stms :: Stms Kernels
stms = Stm Kernels -> Stms Kernels
forall lore. Stm lore -> Stms lore
oneStm (Stm Kernels -> Stms Kernels) -> Stm Kernels -> Stms Kernels
forall a b. (a -> b) -> a -> b
$ Pattern Kernels
-> StmAux (ExpAttr Kernels) -> Exp Kernels -> Stm Kernels
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let Pattern Kernels
pat (Certificates -> () -> StmAux ()
forall attr. Certificates -> attr -> StmAux attr
StmAux Certificates
cs ()) (Exp Kernels -> Stm Kernels) -> Exp Kernels -> Stm Kernels
forall a b. (a -> b) -> a -> b
$ Op Kernels -> Exp Kernels
forall lore. Op lore -> ExpT lore
Op (Op Kernels -> Exp Kernels) -> Op Kernels -> Exp Kernels
forall a b. (a -> b) -> a -> b
$
                          SOAC Kernels -> HostOp Kernels (SOAC Kernels)
forall lore op. op -> HostOp lore op
OtherOp (SOAC Kernels -> HostOp Kernels (SOAC Kernels))
-> SOAC Kernels -> HostOp Kernels (SOAC Kernels)
forall a b. (a -> b) -> a -> b
$ SubExp -> ScremaForm Kernels -> [VName] -> SOAC Kernels
forall lore. SubExp -> ScremaForm lore -> [VName] -> SOAC lore
Screma SubExp
w (LambdaT Kernels -> ScremaForm Kernels
forall lore. Bindable lore => Lambda lore -> ScremaForm lore
mapSOAC LambdaT Kernels
lam') [VName]
used_arrs
               DistAcc -> DistNestT m DistAcc
forall (m :: * -> *) a. Monad m => a -> m a
return (DistAcc -> DistNestT m DistAcc) -> DistAcc -> DistNestT m DistAcc
forall a b. (a -> b) -> a -> b
$ Stms Kernels -> DistAcc -> DistAcc
addStmsToKernel Stms Kernels
stms DistAcc
acc' { distStms :: Stms Kernels
distStms = Stms Kernels
forall a. Monoid a => a
mempty }

-- | Distribute the statements of a map body into the accumulator.
--
-- The statements are converted to a list and handed to 'onStms'; the
-- accumulator that comes back is then passed to 'distribute'.
--
-- Sequential 'Stream's are not handed to 'maybeDistributeStm' directly.
-- Each one is first expanded into ordinary statements with
-- 'sequentialStreamWholeArray', copy-propagated, and spliced back into
-- the statement list, so that its constituents are examined individually.
distributeMapBodyStms :: MonadFreshNames m => DistAcc -> Stms SOACS -> DistNestT m DistAcc
distributeMapBodyStms orig_acc = distribute <=< onStms orig_acc . stmsToList
  where
    -- Nothing left: the accumulator is returned unchanged.
    onStms acc [] = return acc

    -- Sequential stream: expand it and continue with the expansion
    -- placed in front of the remaining statements.
    onStms acc (Let pat (StmAux cs _) (Op (Stream w (Sequential accs) lam arrs)) : stms) = do
      types <- asksScope scopeForSOACs
      stream_stms <-
        snd <$> runBinderT (sequentialStreamWholeArray pat w accs lam arrs) types
      -- Copy-propagate within the freshly generated statements before
      -- they are re-examined.
      (_, stream_stms') <-
        runReaderT (copyPropagateInStms simpleSOACS types stream_stms) types
      -- Every statement the stream expanded into keeps the stream's
      -- certificates.
      onStms acc $ stmsToList (fmap (certify cs) stream_stms') ++ stms

    -- General case: the statements after 'stm' are processed first, and
    -- the resulting accumulator is then offered to 'maybeDistributeStm'
    -- for 'stm' itself.
    onStms acc (stm : stms) =
      -- It is important that stm is in scope if 'maybeDistributeStm'
      -- wants to distribute, even if this causes the slightly silly
      -- situation that stm is in scope of itself.
      withStm stm $ maybeDistributeStm stm =<< onStms acc stms

-- | Handle a map nested inside the one currently being distributed.
--
-- This does no work itself. It looks up the 'distOnInnerMap' callback
-- in the distribution environment and applies it to the loop and the
-- accumulator.
onInnerMap :: Monad m => MapLoop -> DistAcc -> DistNestT m DistAcc
onInnerMap loop acc = do
  f <- asks distOnInnerMap
  f loop acc

-- | Process SOACS statements that have been moved to the top level.
--
-- The statements are transformed into Kernels statements by the
-- 'distOnTopLevelStms' callback from the distribution environment. The
-- result is then recorded with 'addKernel'.
onTopLevelStms :: Monad m => Stms SOACS -> DistNestT m ()
onTopLevelStms stms = do
  f <- asks distOnTopLevelStms
  addKernel =<< f stms

maybeDistributeStm :: MonadFreshNames m => Stm -> DistAcc -> DistNestT m DistAcc

maybeDistributeStm :: Stm -> DistAcc -> DistNestT m DistAcc
maybeDistributeStm bnd :: Stm
bnd@(Let Pattern
pat StmAux (ExpAttr SOACS)
_ (Op (Screma w form arrs))) DistAcc
acc
  | Just Lambda
lam <- ScremaForm SOACS -> Maybe Lambda
forall lore. ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC ScremaForm SOACS
form =
  -- Only distribute inside the map if we can distribute everything
  -- following the map.
  DistAcc -> DistNestT m (Maybe DistAcc)
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc -> DistNestT m (Maybe DistAcc)
distributeIfPossible DistAcc
acc DistNestT m (Maybe DistAcc)
-> (Maybe DistAcc -> DistNestT m DistAcc) -> DistNestT m DistAcc
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Maybe DistAcc
Nothing -> Stm -> DistAcc -> DistNestT m DistAcc
forall (m :: * -> *). Monad m => Stm -> DistAcc -> m DistAcc
addStmToKernel Stm
bnd DistAcc
acc
    Just DistAcc
acc' -> DistAcc -> DistNestT m DistAcc
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc -> DistNestT m DistAcc
distribute (DistAcc -> DistNestT m DistAcc)
-> DistNestT m DistAcc -> DistNestT m DistAcc
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< MapLoop -> DistAcc -> DistNestT m DistAcc
forall (m :: * -> *).
Monad m =>
MapLoop -> DistAcc -> DistNestT m DistAcc
onInnerMap (Pattern -> Certificates -> SubExp -> Lambda -> [VName] -> MapLoop
MapLoop Pattern
pat (Stm -> Certificates
forall lore. Stm lore -> Certificates
stmCerts Stm
bnd) SubExp
w Lambda
lam [VName]
arrs) DistAcc
acc'

maybeDistributeStm bnd :: Stm
bnd@(Let Pattern
pat StmAux (ExpAttr SOACS)
_ (DoLoop [] [(FParam SOACS, SubExp)]
val form :: LoopForm SOACS
form@ForLoop{} Body
body)) DistAcc
acc
  | [PatElemT Type] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (PatternT Type -> [PatElemT Type]
forall attr. PatternT attr -> [PatElemT attr]
patternContextElements PatternT Type
Pattern
pat), Body -> Bool
bodyContainsParallelism Body
body =
  DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
distributeSingleStm DistAcc
acc Stm
bnd DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
-> (Maybe (PostKernels, Result, KernelNest, DistAcc)
    -> DistNestT m DistAcc)
-> DistNestT m DistAcc
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostKernels
kernels, Result
res, KernelNest
nest, DistAcc
acc')
      | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ LoopForm SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm SOACS
form Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
        Just ([Int]
perm, [PatElem]
pat_unused) <- Pattern -> Result -> Maybe ([Int], [PatElem])
permutationAndMissing Pattern
pat Result
res ->
          -- We need to pretend pat_unused was used anyway, by adding
          -- it to the kernel nest.
          Scope Kernels -> DistNestT m DistAcc -> DistNestT m DistAcc
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc -> Scope Kernels
typeEnvFromDistAcc DistAcc
acc') (DistNestT m DistAcc -> DistNestT m DistAcc)
-> DistNestT m DistAcc -> DistNestT m DistAcc
forall a b. (a -> b) -> a -> b
$ do
          PostKernels -> DistNestT m ()
forall (m :: * -> *). Monad m => PostKernels -> DistNestT m ()
addKernels PostKernels
kernels
          KernelNest
nest' <- [PatElem] -> KernelNest -> DistNestT m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElem] -> KernelNest -> m KernelNest
expandKernelNest [PatElem]
pat_unused KernelNest
nest
          Scope SOACS
types <- (Scope Kernels -> Scope SOACS) -> DistNestT m (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope Kernels -> Scope SOACS
scopeForSOACs

          Stms SOACS
bnds <- ReaderT (Scope SOACS) (DistNestT m) (Stms SOACS)
-> Scope SOACS -> DistNestT m (Stms SOACS)
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT
                  (KernelNest
-> SeqLoop -> ReaderT (Scope SOACS) (DistNestT m) (Stms SOACS)
forall (m :: * -> *).
(MonadFreshNames m, HasScope SOACS m) =>
KernelNest -> SeqLoop -> m (Stms SOACS)
interchangeLoops KernelNest
nest' ([Int]
-> Pattern
-> [(FParam SOACS, SubExp)]
-> LoopForm SOACS
-> Body
-> SeqLoop
SeqLoop [Int]
perm Pattern
pat [(FParam SOACS, SubExp)]
val LoopForm SOACS
form Body
body)) Scope SOACS
types
          Stms SOACS -> DistNestT m ()
forall (m :: * -> *). Monad m => Stms SOACS -> DistNestT m ()
onTopLevelStms Stms SOACS
bnds
          DistAcc -> DistNestT m DistAcc
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc
acc'
    Maybe (PostKernels, Result, KernelNest, DistAcc)
_ ->
      Stm -> DistAcc -> DistNestT m DistAcc
forall (m :: * -> *). Monad m => Stm -> DistAcc -> m DistAcc
addStmToKernel Stm
bnd DistAcc
acc

maybeDistributeStm stm :: Stm
stm@(Let Pattern
pat StmAux (ExpAttr SOACS)
_ (If SubExp
cond Body
tbranch Body
fbranch IfAttr (BranchType SOACS)
ret)) DistAcc
acc
  | [PatElemT Type] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (PatternT Type -> [PatElemT Type]
forall attr. PatternT attr -> [PatElemT attr]
patternContextElements PatternT Type
Pattern
pat),
    Body -> Bool
bodyContainsParallelism Body
tbranch Bool -> Bool -> Bool
|| Body -> Bool
bodyContainsParallelism Body
fbranch Bool -> Bool -> Bool
||
    Bool -> Bool
not ((TypeBase ExtShape NoUniqueness -> Bool)
-> [TypeBase ExtShape NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase ExtShape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (IfAttr (TypeBase ExtShape NoUniqueness)
-> [TypeBase ExtShape NoUniqueness]
forall rt. IfAttr rt -> [rt]
ifReturns IfAttr (TypeBase ExtShape NoUniqueness)
IfAttr (BranchType SOACS)
ret)) =
    DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
distributeSingleStm DistAcc
acc Stm
stm DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
-> (Maybe (PostKernels, Result, KernelNest, DistAcc)
    -> DistNestT m DistAcc)
-> DistNestT m DistAcc
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
      Just (PostKernels
kernels, Result
res, KernelNest
nest, DistAcc
acc')
        | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$
          (SubExp -> Names
forall a. FreeIn a => a -> Names
freeIn SubExp
cond Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> IfAttr (TypeBase ExtShape NoUniqueness) -> Names
forall a. FreeIn a => a -> Names
freeIn IfAttr (TypeBase ExtShape NoUniqueness)
IfAttr (BranchType SOACS)
ret) Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
          Just ([Int]
perm, [PatElem]
pat_unused) <- Pattern -> Result -> Maybe ([Int], [PatElem])
permutationAndMissing Pattern
pat Result
res ->
            -- We need to pretend pat_unused was used anyway, by adding
            -- it to the kernel nest.
            Scope Kernels -> DistNestT m DistAcc -> DistNestT m DistAcc
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc -> Scope Kernels
typeEnvFromDistAcc DistAcc
acc') (DistNestT m DistAcc -> DistNestT m DistAcc)
-> DistNestT m DistAcc -> DistNestT m DistAcc
forall a b. (a -> b) -> a -> b
$ do
            KernelNest
nest' <- [PatElem] -> KernelNest -> DistNestT m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElem] -> KernelNest -> m KernelNest
expandKernelNest [PatElem]
pat_unused KernelNest
nest
            PostKernels -> DistNestT m ()
forall (m :: * -> *). Monad m => PostKernels -> DistNestT m ()
addKernels PostKernels
kernels
            Scope SOACS
types <- (Scope Kernels -> Scope SOACS) -> DistNestT m (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope Kernels -> Scope SOACS
scopeForSOACs
            let branch :: Branch
branch = [Int]
-> Pattern
-> SubExp
-> Body
-> Body
-> IfAttr (BranchType SOACS)
-> Branch
Branch [Int]
perm Pattern
pat SubExp
cond Body
tbranch Body
fbranch IfAttr (BranchType SOACS)
ret
            Stms SOACS
stms <- ReaderT (Scope SOACS) (DistNestT m) (Stms SOACS)
-> Scope SOACS -> DistNestT m (Stms SOACS)
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT (KernelNest
-> Branch -> ReaderT (Scope SOACS) (DistNestT m) (Stms SOACS)
forall (m :: * -> *).
(MonadFreshNames m, HasScope SOACS m) =>
KernelNest -> Branch -> m (Stms SOACS)
interchangeBranch KernelNest
nest' Branch
branch) Scope SOACS
types
            Stms SOACS -> DistNestT m ()
forall (m :: * -> *). Monad m => Stms SOACS -> DistNestT m ()
onTopLevelStms Stms SOACS
stms
            DistAcc -> DistNestT m DistAcc
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc
acc'
      Maybe (PostKernels, Result, KernelNest, DistAcc)
_ ->
        Stm -> DistAcc -> DistNestT m DistAcc
forall (m :: * -> *). Monad m => Stm -> DistAcc -> m DistAcc
addStmToKernel Stm
stm DistAcc
acc

maybeDistributeStm (Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Screma w form arrs))) DistAcc
acc
  | Just [Reduce Commutativity
comm Lambda
lam Result
nes] <- ScremaForm SOACS -> Maybe [Reduce SOACS]
forall lore. ScremaForm lore -> Maybe [Reduce lore]
isReduceSOAC ScremaForm SOACS
form,
    Just BinderT SOACS (DistNestT m) ()
m <- Pattern
-> SubExp
-> Commutativity
-> Lambda
-> [(SubExp, VName)]
-> Maybe (BinderT SOACS (DistNestT m) ())
forall (m :: * -> *).
(MonadBinder m, Lore m ~ SOACS) =>
Pattern
-> SubExp
-> Commutativity
-> Lambda
-> [(SubExp, VName)]
-> Maybe (m ())
irwim Pattern
pat SubExp
w Commutativity
comm Lambda
lam ([(SubExp, VName)] -> Maybe (BinderT SOACS (DistNestT m) ()))
-> [(SubExp, VName)] -> Maybe (BinderT SOACS (DistNestT m) ())
forall a b. (a -> b) -> a -> b
$ Result -> [VName] -> [(SubExp, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip Result
nes [VName]
arrs = do
      Scope SOACS
types <- (Scope Kernels -> Scope SOACS) -> DistNestT m (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope Kernels -> Scope SOACS
scopeForSOACs
      (()
_, Stms SOACS
bnds) <- BinderT SOACS (DistNestT m) ()
-> Scope SOACS -> DistNestT m ((), Stms SOACS)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT (Certificates
-> BinderT SOACS (DistNestT m) () -> BinderT SOACS (DistNestT m) ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs BinderT SOACS (DistNestT m) ()
m) Scope SOACS
types
      DistAcc -> Stms SOACS -> DistNestT m DistAcc
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc -> Stms SOACS -> DistNestT m DistAcc
distributeMapBodyStms DistAcc
acc Stms SOACS
bnds

-- Parallelise segmented scatters.
maybeDistributeStm bnd :: Stm
bnd@(Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Scatter w lam ivs as))) DistAcc
acc =
  DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
distributeSingleStm DistAcc
acc Stm
bnd DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
-> (Maybe (PostKernels, Result, KernelNest, DistAcc)
    -> DistNestT m DistAcc)
-> DistNestT m DistAcc
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostKernels
kernels, Result
res, KernelNest
nest, DistAcc
acc')
      | Just ([Int]
perm, [PatElem]
pat_unused) <- Pattern -> Result -> Maybe ([Int], [PatElem])
permutationAndMissing Pattern
pat Result
res ->
        Scope Kernels -> DistNestT m DistAcc -> DistNestT m DistAcc
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc -> Scope Kernels
typeEnvFromDistAcc DistAcc
acc') (DistNestT m DistAcc -> DistNestT m DistAcc)
-> DistNestT m DistAcc -> DistNestT m DistAcc
forall a b. (a -> b) -> a -> b
$ do
          KernelNest
nest' <- [PatElem] -> KernelNest -> DistNestT m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElem] -> KernelNest -> m KernelNest
expandKernelNest [PatElem]
pat_unused KernelNest
nest
          let lam' :: LambdaT Kernels
lam' = Lambda -> LambdaT Kernels
soacsLambdaToKernels Lambda
lam
          PostKernels -> DistNestT m ()
forall (m :: * -> *). Monad m => PostKernels -> DistNestT m ()
addKernels PostKernels
kernels
          Stms Kernels -> DistNestT m ()
forall (m :: * -> *). Monad m => Stms Kernels -> DistNestT m ()
addKernel (Stms Kernels -> DistNestT m ())
-> DistNestT m (Stms Kernels) -> DistNestT m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< KernelNest
-> [Int]
-> Pattern
-> Certificates
-> SubExp
-> LambdaT Kernels
-> [VName]
-> [(SubExp, Int, VName)]
-> DistNestT m (Stms Kernels)
forall (m :: * -> *).
MonadFreshNames m =>
KernelNest
-> [Int]
-> Pattern
-> Certificates
-> SubExp
-> LambdaT Kernels
-> [VName]
-> [(SubExp, Int, VName)]
-> DistNestT m (Stms Kernels)
segmentedScatterKernel KernelNest
nest' [Int]
perm Pattern
pat Certificates
cs SubExp
w LambdaT Kernels
lam' [VName]
ivs [(SubExp, Int, VName)]
as
          DistAcc -> DistNestT m DistAcc
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc
acc'
    Maybe (PostKernels, Result, KernelNest, DistAcc)
_ ->
      Stm -> DistAcc -> DistNestT m DistAcc
forall (m :: * -> *). Monad m => Stm -> DistAcc -> m DistAcc
addStmToKernel Stm
bnd DistAcc
acc

-- Parallelise segmented Hist.
maybeDistributeStm bnd :: Stm
bnd@(Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Hist w ops lam as))) DistAcc
acc =
  DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
distributeSingleStm DistAcc
acc Stm
bnd DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
-> (Maybe (PostKernels, Result, KernelNest, DistAcc)
    -> DistNestT m DistAcc)
-> DistNestT m DistAcc
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostKernels
kernels, Result
res, KernelNest
nest, DistAcc
acc')
      | Just ([Int]
perm, [PatElem]
pat_unused) <- Pattern -> Result -> Maybe ([Int], [PatElem])
permutationAndMissing Pattern
pat Result
res ->
        Scope Kernels -> DistNestT m DistAcc -> DistNestT m DistAcc
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc -> Scope Kernels
typeEnvFromDistAcc DistAcc
acc') (DistNestT m DistAcc -> DistNestT m DistAcc)
-> DistNestT m DistAcc -> DistNestT m DistAcc
forall a b. (a -> b) -> a -> b
$ do
          let lam' :: LambdaT Kernels
lam' = Lambda -> LambdaT Kernels
soacsLambdaToKernels Lambda
lam
          KernelNest
nest' <- [PatElem] -> KernelNest -> DistNestT m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElem] -> KernelNest -> m KernelNest
expandKernelNest [PatElem]
pat_unused KernelNest
nest
          PostKernels -> DistNestT m ()
forall (m :: * -> *). Monad m => PostKernels -> DistNestT m ()
addKernels PostKernels
kernels
          Stms Kernels -> DistNestT m ()
forall (m :: * -> *). Monad m => Stms Kernels -> DistNestT m ()
addKernel (Stms Kernels -> DistNestT m ())
-> DistNestT m (Stms Kernels) -> DistNestT m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< KernelNest
-> [Int]
-> Certificates
-> SubExp
-> [HistOp SOACS]
-> LambdaT Kernels
-> [VName]
-> DistNestT m (Stms Kernels)
forall (m :: * -> *).
MonadFreshNames m =>
KernelNest
-> [Int]
-> Certificates
-> SubExp
-> [HistOp SOACS]
-> LambdaT Kernels
-> [VName]
-> DistNestT m (Stms Kernels)
segmentedHistKernel KernelNest
nest' [Int]
perm Certificates
cs SubExp
w [HistOp SOACS]
ops LambdaT Kernels
lam' [VName]
as
          DistAcc -> DistNestT m DistAcc
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc
acc'
    Maybe (PostKernels, Result, KernelNest, DistAcc)
_ ->
      Stm -> DistAcc -> DistNestT m DistAcc
forall (m :: * -> *). Monad m => Stm -> DistAcc -> m DistAcc
addStmToKernel Stm
bnd DistAcc
acc

-- If the scan can be distributed by itself, we will turn it into a
-- segmented scan.
--
-- If the scan cannot be distributed by itself, it will be
-- sequentialised in the default case for this function.
maybeDistributeStm bnd :: Stm
bnd@(Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Screma w form arrs))) DistAcc
acc
  | Just (Lambda
lam, Result
nes, Lambda
map_lam) <- ScremaForm SOACS -> Maybe (Lambda, Result, Lambda)
forall lore.
ScremaForm lore -> Maybe (Lambda lore, Result, Lambda lore)
isScanomapSOAC ScremaForm SOACS
form =
  DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
distributeSingleStm DistAcc
acc Stm
bnd DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
-> (Maybe (PostKernels, Result, KernelNest, DistAcc)
    -> DistNestT m DistAcc)
-> DistNestT m DistAcc
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostKernels
kernels, Result
res, KernelNest
nest, DistAcc
acc')
      | Just ([Int]
perm, [PatElem]
pat_unused) <- Pattern -> Result -> Maybe ([Int], [PatElem])
permutationAndMissing Pattern
pat Result
res ->
          -- We need to pretend pat_unused was used anyway, by adding
          -- it to the kernel nest.
          Scope Kernels -> DistNestT m DistAcc -> DistNestT m DistAcc
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc -> Scope Kernels
typeEnvFromDistAcc DistAcc
acc') (DistNestT m DistAcc -> DistNestT m DistAcc)
-> DistNestT m DistAcc -> DistNestT m DistAcc
forall a b. (a -> b) -> a -> b
$ do
          KernelNest
nest' <- [PatElem] -> KernelNest -> DistNestT m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElem] -> KernelNest -> m KernelNest
expandKernelNest [PatElem]
pat_unused KernelNest
nest
          let map_lam' :: LambdaT Kernels
map_lam' = Lambda -> LambdaT Kernels
soacsLambdaToKernels Lambda
map_lam
              lam' :: LambdaT Kernels
lam' = Lambda -> LambdaT Kernels
soacsLambdaToKernels Lambda
lam
          Scope Kernels -> DistNestT m DistAcc -> DistNestT m DistAcc
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc -> Scope Kernels
typeEnvFromDistAcc DistAcc
acc') (DistNestT m DistAcc -> DistNestT m DistAcc)
-> DistNestT m DistAcc -> DistNestT m DistAcc
forall a b. (a -> b) -> a -> b
$
            KernelNest
-> [Int]
-> SubExp
-> LambdaT Kernels
-> LambdaT Kernels
-> Result
-> [VName]
-> DistNestT m (Maybe (Stms Kernels))
forall (m :: * -> *).
MonadFreshNames m =>
KernelNest
-> [Int]
-> SubExp
-> LambdaT Kernels
-> LambdaT Kernels
-> Result
-> [VName]
-> DistNestT m (Maybe (Stms Kernels))
segmentedScanomapKernel KernelNest
nest' [Int]
perm SubExp
w LambdaT Kernels
lam' LambdaT Kernels
map_lam' Result
nes [VName]
arrs DistNestT m (Maybe (Stms Kernels))
-> (Maybe (Stms Kernels) -> DistNestT m DistAcc)
-> DistNestT m DistAcc
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>=
            Certificates
-> Stm
-> DistAcc
-> PostKernels
-> DistAcc
-> Maybe (Stms Kernels)
-> DistNestT m DistAcc
forall (m :: * -> *).
MonadFreshNames m =>
Certificates
-> Stm
-> DistAcc
-> PostKernels
-> DistAcc
-> Maybe (Stms Kernels)
-> DistNestT m DistAcc
kernelOrNot Certificates
cs Stm
bnd DistAcc
acc PostKernels
kernels DistAcc
acc'
    Maybe (PostKernels, Result, KernelNest, DistAcc)
_ ->
      Stm -> DistAcc -> DistNestT m DistAcc
forall (m :: * -> *). Monad m => Stm -> DistAcc -> m DistAcc
addStmToKernel Stm
bnd DistAcc
acc

-- if the reduction can be distributed by itself, we will turn it into a
-- segmented reduce.
--
-- If the reduction cannot be distributed by itself, it will be
-- sequentialised in the default case for this function.
maybeDistributeStm bnd :: Stm
bnd@(Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Screma w form arrs))) DistAcc
acc
  | Just ([Reduce SOACS]
reds, Lambda
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda)
forall lore. ScremaForm lore -> Maybe ([Reduce lore], Lambda lore)
isRedomapSOAC ScremaForm SOACS
form,
    Reduce Commutativity
comm Lambda
lam Result
nes <- [Reduce SOACS] -> Reduce SOACS
forall lore. Bindable lore => [Reduce lore] -> Reduce lore
singleReduce [Reduce SOACS]
reds,
    Lambda -> Bool
forall lore. Lambda lore -> Bool
isIdentityLambda Lambda
map_lam Bool -> Bool -> Bool
|| Bool
incrementalFlattening =
  DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
distributeSingleStm DistAcc
acc Stm
bnd DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
-> (Maybe (PostKernels, Result, KernelNest, DistAcc)
    -> DistNestT m DistAcc)
-> DistNestT m DistAcc
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostKernels
kernels, Result
res, KernelNest
nest, DistAcc
acc')
      | Just ([Int]
perm, [PatElem]
pat_unused) <- Pattern -> Result -> Maybe ([Int], [PatElem])
permutationAndMissing Pattern
pat Result
res ->
          -- We need to pretend pat_unused was used anyway, by adding
          -- it to the kernel nest.
          Scope Kernels -> DistNestT m DistAcc -> DistNestT m DistAcc
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc -> Scope Kernels
typeEnvFromDistAcc DistAcc
acc') (DistNestT m DistAcc -> DistNestT m DistAcc)
-> DistNestT m DistAcc -> DistNestT m DistAcc
forall a b. (a -> b) -> a -> b
$ do
          KernelNest
nest' <- [PatElem] -> KernelNest -> DistNestT m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElem] -> KernelNest -> m KernelNest
expandKernelNest [PatElem]
pat_unused KernelNest
nest
          let lam' :: LambdaT Kernels
lam' = Lambda -> LambdaT Kernels
soacsLambdaToKernels Lambda
lam
              map_lam' :: LambdaT Kernels
map_lam' = Lambda -> LambdaT Kernels
soacsLambdaToKernels Lambda
map_lam

          let comm' :: Commutativity
comm' | Lambda -> Bool
forall lore. Lambda lore -> Bool
commutativeLambda Lambda
lam = Commutativity
Commutative
                    | Bool
otherwise             = Commutativity
comm

          KernelNest
-> [Int]
-> SubExp
-> Commutativity
-> LambdaT Kernels
-> LambdaT Kernels
-> Result
-> [VName]
-> DistNestT m (Maybe (Stms Kernels))
forall (m :: * -> *).
MonadFreshNames m =>
KernelNest
-> [Int]
-> SubExp
-> Commutativity
-> LambdaT Kernels
-> LambdaT Kernels
-> Result
-> [VName]
-> DistNestT m (Maybe (Stms Kernels))
regularSegmentedRedomapKernel KernelNest
nest' [Int]
perm SubExp
w Commutativity
comm' LambdaT Kernels
lam' LambdaT Kernels
map_lam' Result
nes [VName]
arrs DistNestT m (Maybe (Stms Kernels))
-> (Maybe (Stms Kernels) -> DistNestT m DistAcc)
-> DistNestT m DistAcc
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>=
            Certificates
-> Stm
-> DistAcc
-> PostKernels
-> DistAcc
-> Maybe (Stms Kernels)
-> DistNestT m DistAcc
forall (m :: * -> *).
MonadFreshNames m =>
Certificates
-> Stm
-> DistAcc
-> PostKernels
-> DistAcc
-> Maybe (Stms Kernels)
-> DistNestT m DistAcc
kernelOrNot Certificates
cs Stm
bnd DistAcc
acc PostKernels
kernels DistAcc
acc'
    Maybe (PostKernels, Result, KernelNest, DistAcc)
_ ->
      Stm -> DistAcc -> DistNestT m DistAcc
forall (m :: * -> *). Monad m => Stm -> DistAcc -> m DistAcc
addStmToKernel Stm
bnd DistAcc
acc

maybeDistributeStm (Let Pattern
pat (StmAux Certificates
cs ExpAttr SOACS
_) (Op (Screma w form arrs))) DistAcc
acc
  | Bool
incrementalFlattening Bool -> Bool -> Bool
|| Maybe ([Reduce SOACS], Lambda) -> Bool
forall a. Maybe a -> Bool
isNothing (ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda)
forall lore. ScremaForm lore -> Maybe ([Reduce lore], Lambda lore)
isRedomapSOAC ScremaForm SOACS
form) = do
  -- This with-loop is too complicated for us to immediately do
  -- anything, so split it up and try again.
  Scope SOACS
scope <- (Scope Kernels -> Scope SOACS) -> DistNestT m (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope Kernels -> Scope SOACS
scopeForSOACs
  DistAcc -> Stms SOACS -> DistNestT m DistAcc
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc -> Stms SOACS -> DistNestT m DistAcc
distributeMapBodyStms DistAcc
acc (Stms SOACS -> DistNestT m DistAcc)
-> (((), Stms SOACS) -> Stms SOACS)
-> ((), Stms SOACS)
-> DistNestT m DistAcc
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm -> Stm) -> Stms SOACS -> Stms SOACS
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> Stm -> Stm
forall lore. Certificates -> Stm lore -> Stm lore
certify Certificates
cs) (Stms SOACS -> Stms SOACS)
-> (((), Stms SOACS) -> Stms SOACS)
-> ((), Stms SOACS)
-> Stms SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((), Stms SOACS) -> Stms SOACS
forall a b. (a, b) -> b
snd (((), Stms SOACS) -> DistNestT m DistAcc)
-> DistNestT m ((), Stms SOACS) -> DistNestT m DistAcc
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
    BinderT SOACS (DistNestT m) ()
-> Scope SOACS -> DistNestT m ((), Stms SOACS)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT (Pattern (Lore (BinderT SOACS (DistNestT m)))
-> SubExp
-> ScremaForm (Lore (BinderT SOACS (DistNestT m)))
-> [VName]
-> BinderT SOACS (DistNestT m) ()
forall (m :: * -> *).
(MonadBinder m, Op (Lore m) ~ SOAC (Lore m), Bindable (Lore m)) =>
Pattern (Lore m)
-> SubExp -> ScremaForm (Lore m) -> [VName] -> m ()
dissectScrema Pattern (Lore (BinderT SOACS (DistNestT m)))
Pattern
pat SubExp
w ScremaForm (Lore (BinderT SOACS (DistNestT m)))
ScremaForm SOACS
form [VName]
arrs) Scope SOACS
scope

maybeDistributeStm (Let Pattern
pat StmAux (ExpAttr SOACS)
aux (BasicOp (Replicate (Shape (SubExp
d:Result
ds)) SubExp
v))) DistAcc
acc
  | [Type
t] <- PatternT Type -> [Type]
forall attr. Typed attr => PatternT attr -> [Type]
patternTypes PatternT Type
Pattern
pat = do
      -- XXX: We need a temporary dummy binding to prevent an empty
      -- map body.  The kernel extractor does not like empty map
      -- bodies.
      VName
tmp <- String -> DistNestT m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"tmp"
      let rowt :: Type
rowt = Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType Type
t
          newbnd :: Stm
newbnd = Pattern -> StmAux (ExpAttr SOACS) -> Exp SOACS -> Stm
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let Pattern
pat StmAux (ExpAttr SOACS)
aux (Exp SOACS -> Stm) -> Exp SOACS -> Stm
forall a b. (a -> b) -> a -> b
$ Op SOACS -> Exp SOACS
forall lore. Op lore -> ExpT lore
Op (Op SOACS -> Exp SOACS) -> Op SOACS -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> ScremaForm SOACS -> [VName] -> SOAC SOACS
forall lore. SubExp -> ScremaForm lore -> [VName] -> SOAC lore
Screma SubExp
d (Lambda -> ScremaForm SOACS
forall lore. Bindable lore => Lambda lore -> ScremaForm lore
mapSOAC Lambda
lam) []
          tmpbnd :: Stm
tmpbnd = Pattern -> StmAux (ExpAttr SOACS) -> Exp SOACS -> Stm
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let ([PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [] [VName -> Type -> PatElemT Type
forall attr. VName -> attr -> PatElemT attr
PatElem VName
tmp Type
rowt]) StmAux (ExpAttr SOACS)
aux (Exp SOACS -> Stm) -> Exp SOACS -> Stm
forall a b. (a -> b) -> a -> b
$
                   BasicOp SOACS -> Exp SOACS
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp SOACS -> Exp SOACS) -> BasicOp SOACS -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp SOACS
forall lore. Shape -> SubExp -> BasicOp lore
Replicate (Result -> Shape
forall d. [d] -> ShapeBase d
Shape Result
ds) SubExp
v
          lam :: Lambda
lam = Lambda :: forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda { lambdaReturnType :: [Type]
lambdaReturnType = [Type
rowt]
                       , lambdaParams :: [LParam SOACS]
lambdaParams = []
                       , lambdaBody :: Body
lambdaBody = Stms SOACS -> Result -> Body
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody (Stm -> Stms SOACS
forall lore. Stm lore -> Stms lore
oneStm Stm
tmpbnd) [VName -> SubExp
Var VName
tmp]
                       }
      Stm -> DistAcc -> DistNestT m DistAcc
forall (m :: * -> *).
MonadFreshNames m =>
Stm -> DistAcc -> DistNestT m DistAcc
maybeDistributeStm Stm
newbnd DistAcc
acc

maybeDistributeStm bnd :: Stm
bnd@(Let Pattern
_ StmAux (ExpAttr SOACS)
aux (BasicOp Copy{})) DistAcc
acc =
  DistAcc
-> Stm
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
-> DistNestT m DistAcc
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc
-> Stm
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
-> DistNestT m DistAcc
distributeSingleUnaryStm DistAcc
acc Stm
bnd ((KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
 -> DistNestT m DistAcc)
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
-> DistNestT m DistAcc
forall a b. (a -> b) -> a -> b
$ \KernelNest
_ Pattern
outerpat VName
arr ->
  Stms Kernels -> DistNestT m (Stms Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms Kernels -> DistNestT m (Stms Kernels))
-> Stms Kernels -> DistNestT m (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ Stm Kernels -> Stms Kernels
forall lore. Stm lore -> Stms lore
oneStm (Stm Kernels -> Stms Kernels) -> Stm Kernels -> Stms Kernels
forall a b. (a -> b) -> a -> b
$ Pattern Kernels
-> StmAux (ExpAttr Kernels) -> Exp Kernels -> Stm Kernels
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let Pattern
Pattern Kernels
outerpat StmAux (ExpAttr SOACS)
StmAux (ExpAttr Kernels)
aux (Exp Kernels -> Stm Kernels) -> Exp Kernels -> Stm Kernels
forall a b. (a -> b) -> a -> b
$ BasicOp Kernels -> Exp Kernels
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp Kernels -> Exp Kernels) -> BasicOp Kernels -> Exp Kernels
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp Kernels
forall lore. VName -> BasicOp lore
Copy VName
arr

-- Opaques are applied to the full array, because otherwise they can
-- drastically inhibit parallelisation in some cases.
maybeDistributeStm bnd :: Stm
bnd@(Let (Pattern [] [PatElem
pe]) StmAux (ExpAttr SOACS)
aux (BasicOp Opaque{})) DistAcc
acc
  | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> Type -> Bool
forall a b. (a -> b) -> a -> b
$ PatElemT Type -> Type
forall t. Typed t => t -> Type
typeOf PatElemT Type
PatElem
pe =
      DistAcc
-> Stm
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
-> DistNestT m DistAcc
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc
-> Stm
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
-> DistNestT m DistAcc
distributeSingleUnaryStm DistAcc
acc Stm
bnd ((KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
 -> DistNestT m DistAcc)
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
-> DistNestT m DistAcc
forall a b. (a -> b) -> a -> b
$ \KernelNest
_ Pattern
outerpat VName
arr ->
      Stms Kernels -> DistNestT m (Stms Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms Kernels -> DistNestT m (Stms Kernels))
-> Stms Kernels -> DistNestT m (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ Stm Kernels -> Stms Kernels
forall lore. Stm lore -> Stms lore
oneStm (Stm Kernels -> Stms Kernels) -> Stm Kernels -> Stms Kernels
forall a b. (a -> b) -> a -> b
$ Pattern Kernels
-> StmAux (ExpAttr Kernels) -> Exp Kernels -> Stm Kernels
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let Pattern
Pattern Kernels
outerpat StmAux (ExpAttr SOACS)
StmAux (ExpAttr Kernels)
aux (Exp Kernels -> Stm Kernels) -> Exp Kernels -> Stm Kernels
forall a b. (a -> b) -> a -> b
$ BasicOp Kernels -> Exp Kernels
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp Kernels -> Exp Kernels) -> BasicOp Kernels -> Exp Kernels
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp Kernels
forall lore. VName -> BasicOp lore
Copy VName
arr

maybeDistributeStm bnd :: Stm
bnd@(Let Pattern
_ StmAux (ExpAttr SOACS)
aux (BasicOp (Rearrange [Int]
perm VName
_))) DistAcc
acc =
  DistAcc
-> Stm
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
-> DistNestT m DistAcc
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc
-> Stm
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
-> DistNestT m DistAcc
distributeSingleUnaryStm DistAcc
acc Stm
bnd ((KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
 -> DistNestT m DistAcc)
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
-> DistNestT m DistAcc
forall a b. (a -> b) -> a -> b
$ \KernelNest
nest Pattern
outerpat VName
arr -> do
    let r :: Int
r = [LoopNesting] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (KernelNest -> [LoopNesting]
forall a b. (a, b) -> b
snd KernelNest
nest) Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1
        perm' :: [Int]
perm' = [Int
0..Int
rInt -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] [Int] -> [Int] -> [Int]
forall a. [a] -> [a] -> [a]
++ (Int -> Int) -> [Int] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> Int -> Int
forall a. Num a => a -> a -> a
+Int
r) [Int]
perm
    -- We need to add a copy, because the original map nest
    -- will have produced an array without aliases, and so must we.
    VName
arr' <- String -> DistNestT m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> DistNestT m VName) -> String -> DistNestT m VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString VName
arr
    Type
arr_t <- VName -> DistNestT m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
    Stms Kernels -> DistNestT m (Stms Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms Kernels -> DistNestT m (Stms Kernels))
-> Stms Kernels -> DistNestT m (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ [Stm Kernels] -> Stms Kernels
forall lore. [Stm lore] -> Stms lore
stmsFromList
      [Pattern Kernels
-> StmAux (ExpAttr Kernels) -> Exp Kernels -> Stm Kernels
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let ([PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [] [VName -> Type -> PatElemT Type
forall attr. VName -> attr -> PatElemT attr
PatElem VName
arr' Type
arr_t]) StmAux (ExpAttr SOACS)
StmAux (ExpAttr Kernels)
aux (Exp Kernels -> Stm Kernels) -> Exp Kernels -> Stm Kernels
forall a b. (a -> b) -> a -> b
$ BasicOp Kernels -> Exp Kernels
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp Kernels -> Exp Kernels) -> BasicOp Kernels -> Exp Kernels
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp Kernels
forall lore. VName -> BasicOp lore
Copy VName
arr,
       Pattern Kernels
-> StmAux (ExpAttr Kernels) -> Exp Kernels -> Stm Kernels
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let Pattern
Pattern Kernels
outerpat StmAux (ExpAttr SOACS)
StmAux (ExpAttr Kernels)
aux (Exp Kernels -> Stm Kernels) -> Exp Kernels -> Stm Kernels
forall a b. (a -> b) -> a -> b
$ BasicOp Kernels -> Exp Kernels
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp Kernels -> Exp Kernels) -> BasicOp Kernels -> Exp Kernels
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp Kernels
forall lore. [Int] -> VName -> BasicOp lore
Rearrange [Int]
perm' VName
arr']

maybeDistributeStm bnd :: Stm
bnd@(Let Pattern
_ StmAux (ExpAttr SOACS)
aux (BasicOp (Reshape ShapeChange SubExp
reshape VName
_))) DistAcc
acc =
  DistAcc
-> Stm
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
-> DistNestT m DistAcc
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc
-> Stm
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
-> DistNestT m DistAcc
distributeSingleUnaryStm DistAcc
acc Stm
bnd ((KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
 -> DistNestT m DistAcc)
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
-> DistNestT m DistAcc
forall a b. (a -> b) -> a -> b
$ \KernelNest
nest Pattern
outerpat VName
arr -> do
    let reshape' :: ShapeChange SubExp
reshape' = (SubExp -> DimChange SubExp) -> Result -> ShapeChange SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimChange SubExp
forall d. d -> DimChange d
DimNew (KernelNest -> Result
kernelNestWidths KernelNest
nest) ShapeChange SubExp -> ShapeChange SubExp -> ShapeChange SubExp
forall a. [a] -> [a] -> [a]
++
                   (SubExp -> DimChange SubExp) -> Result -> ShapeChange SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimChange SubExp
forall d. d -> DimChange d
DimNew (ShapeChange SubExp -> Result
forall d. ShapeChange d -> [d]
newDims ShapeChange SubExp
reshape)
    Stms Kernels -> DistNestT m (Stms Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms Kernels -> DistNestT m (Stms Kernels))
-> Stms Kernels -> DistNestT m (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ Stm Kernels -> Stms Kernels
forall lore. Stm lore -> Stms lore
oneStm (Stm Kernels -> Stms Kernels) -> Stm Kernels -> Stms Kernels
forall a b. (a -> b) -> a -> b
$ Pattern Kernels
-> StmAux (ExpAttr Kernels) -> Exp Kernels -> Stm Kernels
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let Pattern
Pattern Kernels
outerpat StmAux (ExpAttr SOACS)
StmAux (ExpAttr Kernels)
aux (Exp Kernels -> Stm Kernels) -> Exp Kernels -> Stm Kernels
forall a b. (a -> b) -> a -> b
$ BasicOp Kernels -> Exp Kernels
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp Kernels -> Exp Kernels) -> BasicOp Kernels -> Exp Kernels
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> VName -> BasicOp Kernels
forall lore. ShapeChange SubExp -> VName -> BasicOp lore
Reshape ShapeChange SubExp
reshape' VName
arr

maybeDistributeStm stm :: Stm
stm@(Let Pattern
_ StmAux (ExpAttr SOACS)
aux (BasicOp (Rotate Result
rots VName
_))) DistAcc
acc =
  DistAcc
-> Stm
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
-> DistNestT m DistAcc
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc
-> Stm
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
-> DistNestT m DistAcc
distributeSingleUnaryStm DistAcc
acc Stm
stm ((KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
 -> DistNestT m DistAcc)
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
-> DistNestT m DistAcc
forall a b. (a -> b) -> a -> b
$ \KernelNest
nest Pattern
outerpat VName
arr -> do
    let rots' :: Result
rots' = (SubExp -> SubExp) -> Result -> Result
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> SubExp -> SubExp
forall a b. a -> b -> a
const (SubExp -> SubExp -> SubExp) -> SubExp -> SubExp -> SubExp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
0) (KernelNest -> Result
kernelNestWidths KernelNest
nest) Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
rots
    Stms Kernels -> DistNestT m (Stms Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms Kernels -> DistNestT m (Stms Kernels))
-> Stms Kernels -> DistNestT m (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ Stm Kernels -> Stms Kernels
forall lore. Stm lore -> Stms lore
oneStm (Stm Kernels -> Stms Kernels) -> Stm Kernels -> Stms Kernels
forall a b. (a -> b) -> a -> b
$ Pattern Kernels
-> StmAux (ExpAttr Kernels) -> Exp Kernels -> Stm Kernels
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let Pattern
Pattern Kernels
outerpat StmAux (ExpAttr SOACS)
StmAux (ExpAttr Kernels)
aux (Exp Kernels -> Stm Kernels) -> Exp Kernels -> Stm Kernels
forall a b. (a -> b) -> a -> b
$ BasicOp Kernels -> Exp Kernels
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp Kernels -> Exp Kernels) -> BasicOp Kernels -> Exp Kernels
forall a b. (a -> b) -> a -> b
$ Result -> VName -> BasicOp Kernels
forall lore. Result -> VName -> BasicOp lore
Rotate Result
rots' VName
arr

maybeDistributeStm stm :: Stm
stm@(Let Pattern
pat StmAux (ExpAttr SOACS)
aux (BasicOp (Update VName
arr Slice SubExp
slice (Var VName
v)))) DistAcc
acc
  | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Result -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null (Result -> Bool) -> Result -> Bool
forall a b. (a -> b) -> a -> b
$ Slice SubExp -> Result
forall d. Slice d -> [d]
sliceDims Slice SubExp
slice =
    DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
distributeSingleStm DistAcc
acc Stm
stm DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
-> (Maybe (PostKernels, Result, KernelNest, DistAcc)
    -> DistNestT m DistAcc)
-> DistNestT m DistAcc
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostKernels
kernels, Result
res, KernelNest
nest, DistAcc
acc')
      | Result
res Result -> Result -> Bool
forall a. Eq a => a -> a -> Bool
== (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var (PatternT Type -> [VName]
forall attr. PatternT attr -> [VName]
patternNames (PatternT Type -> [VName]) -> PatternT Type -> [VName]
forall a b. (a -> b) -> a -> b
$ Stm -> Pattern
forall lore. Stm lore -> Pattern lore
stmPattern Stm
stm),
        Just ([Int]
perm, [PatElem]
pat_unused) <- Pattern -> Result -> Maybe ([Int], [PatElem])
permutationAndMissing Pattern
pat Result
res -> do
          PostKernels -> DistNestT m ()
forall (m :: * -> *). Monad m => PostKernels -> DistNestT m ()
addKernels PostKernels
kernels
          Scope Kernels -> DistNestT m DistAcc -> DistNestT m DistAcc
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc -> Scope Kernels
typeEnvFromDistAcc DistAcc
acc') (DistNestT m DistAcc -> DistNestT m DistAcc)
-> DistNestT m DistAcc -> DistNestT m DistAcc
forall a b. (a -> b) -> a -> b
$ do
            KernelNest
nest' <- [PatElem] -> KernelNest -> DistNestT m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElem] -> KernelNest -> m KernelNest
expandKernelNest [PatElem]
pat_unused KernelNest
nest
            Stms Kernels -> DistNestT m ()
forall (m :: * -> *). Monad m => Stms Kernels -> DistNestT m ()
addKernel (Stms Kernels -> DistNestT m ())
-> DistNestT m (Stms Kernels) -> DistNestT m ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
              KernelNest
-> [Int]
-> Certificates
-> VName
-> Slice SubExp
-> VName
-> DistNestT m (Stms Kernels)
forall (m :: * -> *).
MonadFreshNames m =>
KernelNest
-> [Int]
-> Certificates
-> VName
-> Slice SubExp
-> VName
-> DistNestT m (Stms Kernels)
segmentedUpdateKernel KernelNest
nest' [Int]
perm (StmAux () -> Certificates
forall attr. StmAux attr -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpAttr SOACS)
aux) VName
arr Slice SubExp
slice VName
v
            DistAcc -> DistNestT m DistAcc
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc
acc'

    Maybe (PostKernels, Result, KernelNest, DistAcc)
_ -> Stm -> DistAcc -> DistNestT m DistAcc
forall (m :: * -> *). Monad m => Stm -> DistAcc -> m DistAcc
addStmToKernel Stm
stm DistAcc
acc

-- XXX?  This rule is present to avoid the case where an in-place
-- update is distributed as its own kernel, as this would mean thread
-- then writes the entire array that it updated.  This is problematic
-- because the in-place updates is O(1), but writing the array is
-- O(n).  It is OK if the in-place update is preceded, followed, or
-- nested inside a sequential loop or similar, because that will
-- probably be O(n) by itself.  As a hack, we only distribute if there
-- does not appear to be a loop following.  The better solution is to
-- depend on memory block merging for this optimisation, but it is not
-- ready yet.
maybeDistributeStm (Let Pattern
pat StmAux (ExpAttr SOACS)
aux (BasicOp (Update VName
arr [DimFix SubExp
i] SubExp
v))) DistAcc
acc
  | [Type
t] <- PatternT Type -> [Type]
forall attr. Typed attr => PatternT attr -> [Type]
patternTypes PatternT Type
Pattern
pat,
    Type -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank Type
t Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1,
    Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (Stm Kernels -> Bool) -> Stms Kernels -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (Exp Kernels -> Bool
forall lore. ExpT lore -> Bool
amortises (Exp Kernels -> Bool)
-> (Stm Kernels -> Exp Kernels) -> Stm Kernels -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm Kernels -> Exp Kernels
forall lore. Stm lore -> Exp lore
stmExp) (Stms Kernels -> Bool) -> Stms Kernels -> Bool
forall a b. (a -> b) -> a -> b
$ DistAcc -> Stms Kernels
distStms DistAcc
acc = do
      let w :: SubExp
w = Int -> Type -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 Type
t
          et :: Type
et = Int -> Type -> Type
forall shape u.
ArrayShape shape =>
Int -> TypeBase shape u -> TypeBase shape u
stripArray Int
1 Type
t
          lam :: Lambda
lam = Lambda :: forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda { lambdaParams :: [LParam SOACS]
lambdaParams = []
                       , lambdaReturnType :: [Type]
lambdaReturnType = [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int32, Type
et]
                       , lambdaBody :: Body
lambdaBody = Stms SOACS -> Result -> Body
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody Stms SOACS
forall a. Monoid a => a
mempty [SubExp
i, SubExp
v] }
      Stm -> DistAcc -> DistNestT m DistAcc
forall (m :: * -> *).
MonadFreshNames m =>
Stm -> DistAcc -> DistNestT m DistAcc
maybeDistributeStm (Pattern -> StmAux (ExpAttr SOACS) -> Exp SOACS -> Stm
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let Pattern
pat StmAux (ExpAttr SOACS)
aux (Exp SOACS -> Stm) -> Exp SOACS -> Stm
forall a b. (a -> b) -> a -> b
$ Op SOACS -> Exp SOACS
forall lore. Op lore -> ExpT lore
Op (Op SOACS -> Exp SOACS) -> Op SOACS -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> Lambda -> [VName] -> [(SubExp, Int, VName)] -> SOAC SOACS
forall lore.
SubExp
-> Lambda lore -> [VName] -> [(SubExp, Int, VName)] -> SOAC lore
Scatter (IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
1) Lambda
lam [] [(SubExp
w, Int
1, VName
arr)]) DistAcc
acc
  where amortises :: ExpT lore -> Bool
amortises DoLoop{} = Bool
True
        amortises Op{} = Bool
True
        amortises ExpT lore
_ = Bool
False

maybeDistributeStm stm :: Stm
stm@(Let Pattern
_ StmAux (ExpAttr SOACS)
aux (BasicOp (Concat Int
d VName
x [VName]
xs SubExp
w))) DistAcc
acc =
  DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
distributeSingleStm DistAcc
acc Stm
stm DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
-> (Maybe (PostKernels, Result, KernelNest, DistAcc)
    -> DistNestT m DistAcc)
-> DistNestT m DistAcc
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostKernels
kernels, Result
_, KernelNest
nest, DistAcc
acc') ->
      Scope Kernels -> DistNestT m DistAcc -> DistNestT m DistAcc
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc -> Scope Kernels
typeEnvFromDistAcc DistAcc
acc') (DistNestT m DistAcc -> DistNestT m DistAcc)
-> DistNestT m DistAcc -> DistNestT m DistAcc
forall a b. (a -> b) -> a -> b
$
      KernelNest -> DistNestT m (Maybe (Stms Kernels))
segmentedConcat KernelNest
nest DistNestT m (Maybe (Stms Kernels))
-> (Maybe (Stms Kernels) -> DistNestT m DistAcc)
-> DistNestT m DistAcc
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>=
      Certificates
-> Stm
-> DistAcc
-> PostKernels
-> DistAcc
-> Maybe (Stms Kernels)
-> DistNestT m DistAcc
forall (m :: * -> *).
MonadFreshNames m =>
Certificates
-> Stm
-> DistAcc
-> PostKernels
-> DistAcc
-> Maybe (Stms Kernels)
-> DistNestT m DistAcc
kernelOrNot (StmAux () -> Certificates
forall attr. StmAux attr -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpAttr SOACS)
aux) Stm
stm DistAcc
acc PostKernels
kernels DistAcc
acc'
    Maybe (PostKernels, Result, KernelNest, DistAcc)
_ ->
      Stm -> DistAcc -> DistNestT m DistAcc
forall (m :: * -> *). Monad m => Stm -> DistAcc -> m DistAcc
addStmToKernel Stm
stm DistAcc
acc

  where segmentedConcat :: KernelNest -> DistNestT m (Maybe (Stms Kernels))
segmentedConcat KernelNest
nest =
          KernelNest
-> [Int]
-> SubExp
-> Names
-> Names
-> Result
-> [VName]
-> (Pattern
    -> [(VName, SubExp)]
    -> [KernelInput]
    -> Result
    -> [VName]
    -> [VName]
    -> BinderT Kernels m ())
-> DistNestT m (Maybe (Stms Kernels))
forall (m :: * -> *).
MonadFreshNames m =>
KernelNest
-> [Int]
-> SubExp
-> Names
-> Names
-> Result
-> [VName]
-> (Pattern
    -> [(VName, SubExp)]
    -> [KernelInput]
    -> Result
    -> [VName]
    -> [VName]
    -> BinderT Kernels m ())
-> DistNestT m (Maybe (Stms Kernels))
isSegmentedOp KernelNest
nest [Int
0] SubExp
w Names
forall a. Monoid a => a
mempty Names
forall a. Monoid a => a
mempty [] (VName
xVName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
:[VName]
xs) ((Pattern
  -> [(VName, SubExp)]
  -> [KernelInput]
  -> Result
  -> [VName]
  -> [VName]
  -> BinderT Kernels m ())
 -> DistNestT m (Maybe (Stms Kernels)))
-> (Pattern
    -> [(VName, SubExp)]
    -> [KernelInput]
    -> Result
    -> [VName]
    -> [VName]
    -> BinderT Kernels m ())
-> DistNestT m (Maybe (Stms Kernels))
forall a b. (a -> b) -> a -> b
$
          \Pattern
pat [(VName, SubExp)]
_ [KernelInput]
_ Result
_ (VName
x':[VName]
xs') [VName]
_ ->
            let d' :: Int
d' = Int
d Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [LoopNesting] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (KernelNest -> [LoopNesting]
forall a b. (a, b) -> b
snd KernelNest
nest) Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1
            in Stm (Lore (BinderT Kernels m)) -> BinderT Kernels m ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm (Stm (Lore (BinderT Kernels m)) -> BinderT Kernels m ())
-> Stm (Lore (BinderT Kernels m)) -> BinderT Kernels m ()
forall a b. (a -> b) -> a -> b
$ Pattern Kernels
-> StmAux (ExpAttr Kernels) -> Exp Kernels -> Stm Kernels
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let Pattern
Pattern Kernels
pat StmAux (ExpAttr SOACS)
StmAux (ExpAttr Kernels)
aux (Exp Kernels -> Stm Kernels) -> Exp Kernels -> Stm Kernels
forall a b. (a -> b) -> a -> b
$ BasicOp Kernels -> Exp Kernels
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp Kernels -> Exp Kernels) -> BasicOp Kernels -> Exp Kernels
forall a b. (a -> b) -> a -> b
$ Int -> VName -> [VName] -> SubExp -> BasicOp Kernels
forall lore. Int -> VName -> [VName] -> SubExp -> BasicOp lore
Concat Int
d' VName
x' [VName]
xs' SubExp
w

maybeDistributeStm Stm
bnd DistAcc
acc =
  Stm -> DistAcc -> DistNestT m DistAcc
forall (m :: * -> *). Monad m => Stm -> DistAcc -> m DistAcc
addStmToKernel Stm
bnd DistAcc
acc

distributeSingleUnaryStm :: MonadFreshNames m =>
                            DistAcc -> Stm
                         -> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Out.Kernels))
                         -> DistNestT m DistAcc
distributeSingleUnaryStm :: DistAcc
-> Stm
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels))
-> DistNestT m DistAcc
distributeSingleUnaryStm DistAcc
acc Stm
bnd KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels)
f =
  DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
forall (m :: * -> *).
MonadFreshNames m =>
DistAcc
-> Stm
-> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
distributeSingleStm DistAcc
acc Stm
bnd DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
-> (Maybe (PostKernels, Result, KernelNest, DistAcc)
    -> DistNestT m DistAcc)
-> DistNestT m DistAcc
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostKernels
kernels, Result
res, KernelNest
nest, DistAcc
acc')
      | Result
res Result -> Result -> Bool
forall a. Eq a => a -> a -> Bool
== (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var (PatternT Type -> [VName]
forall attr. PatternT attr -> [VName]
patternNames (PatternT Type -> [VName]) -> PatternT Type -> [VName]
forall a b. (a -> b) -> a -> b
$ Stm -> Pattern
forall lore. Stm lore -> Pattern lore
stmPattern Stm
bnd),
        (LoopNesting
outer, [LoopNesting]
inners) <- KernelNest
nest,
        [(Param Type
arr_p, VName
arr)] <- LoopNesting -> [(Param Type, VName)]
loopNestingParamsAndArrs LoopNesting
outer,
        KernelNest -> Names
boundInKernelNest KernelNest
nest Names -> Names -> Names
`namesIntersection` Stm -> Names
forall a. FreeIn a => a -> Names
freeIn Stm
bnd
        Names -> Names -> Bool
forall a. Eq a => a -> a -> Bool
== VName -> Names
oneName (Param Type -> VName
forall attr. Param attr -> VName
paramName Param Type
arr_p) -> do
          PostKernels -> DistNestT m ()
forall (m :: * -> *). Monad m => PostKernels -> DistNestT m ()
addKernels PostKernels
kernels
          let outerpat :: Pattern Kernels
outerpat = LoopNesting -> Pattern Kernels
loopNestingPattern (LoopNesting -> Pattern Kernels) -> LoopNesting -> Pattern Kernels
forall a b. (a -> b) -> a -> b
$ KernelNest -> LoopNesting
forall a b. (a, b) -> a
fst KernelNest
nest
          Scope Kernels -> DistNestT m DistAcc -> DistNestT m DistAcc
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (DistAcc -> Scope Kernels
typeEnvFromDistAcc DistAcc
acc') (DistNestT m DistAcc -> DistNestT m DistAcc)
-> DistNestT m DistAcc -> DistNestT m DistAcc
forall a b. (a -> b) -> a -> b
$ do
            (VName
arr', Stms Kernels
pre_stms) <- VName -> [LoopNesting] -> DistNestT m (VName, Stms Kernels)
forall (m :: * -> *) lore lore.
(HasScope lore m, MonadFreshNames m, LetAttr lore ~ Type,
 ExpAttr lore ~ ()) =>
VName -> [LoopNesting] -> m (VName, Stms lore)
repeatMissing VName
arr (LoopNesting
outerLoopNesting -> [LoopNesting] -> [LoopNesting]
forall a. a -> [a] -> [a]
:[LoopNesting]
inners)
            Stms Kernels
f_stms <- Stms Kernels
-> DistNestT m (Stms Kernels) -> DistNestT m (Stms Kernels)
forall lore a (m :: * -> *) b.
(Scoped lore a, LocalScope lore m) =>
a -> m b -> m b
inScopeOf Stms Kernels
pre_stms (DistNestT m (Stms Kernels) -> DistNestT m (Stms Kernels))
-> DistNestT m (Stms Kernels) -> DistNestT m (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ KernelNest -> Pattern -> VName -> DistNestT m (Stms Kernels)
f KernelNest
nest Pattern
Pattern Kernels
outerpat VName
arr'
            Stms Kernels -> DistNestT m ()
forall (m :: * -> *). Monad m => Stms Kernels -> DistNestT m ()
addKernel (Stms Kernels -> DistNestT m ()) -> Stms Kernels -> DistNestT m ()
forall a b. (a -> b) -> a -> b
$ Stms Kernels
pre_stms Stms Kernels -> Stms Kernels -> Stms Kernels
forall a. Semigroup a => a -> a -> a
<> Stms Kernels
f_stms
            DistAcc -> DistNestT m DistAcc
forall (m :: * -> *) a. Monad m => a -> m a
return DistAcc
acc'
    Maybe (PostKernels, Result, KernelNest, DistAcc)
_ -> Stm -> DistAcc -> DistNestT m DistAcc
forall (m :: * -> *). Monad m => Stm -> DistAcc -> m DistAcc
addStmToKernel Stm
bnd DistAcc
acc
  where -- | For an imperfectly mapped array, repeat the missing
        -- dimensions to make it look like it was in fact perfectly
        -- mapped.
        repeatMissing :: VName -> [LoopNesting] -> m (VName, Stms lore)
repeatMissing VName
arr [LoopNesting]
inners = do
          Type
arr_t <- VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
          let shapes :: [Shape]
shapes = VName -> Type -> [LoopNesting] -> [Shape]
forall shape u.
ArrayShape shape =>
VName -> TypeBase shape u -> [LoopNesting] -> [Shape]
determineRepeats VName
arr Type
arr_t [LoopNesting]
inners
          if (Shape -> Bool) -> [Shape] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Shape -> Shape -> Bool
forall a. Eq a => a -> a -> Bool
==Result -> Shape
forall d. [d] -> ShapeBase d
Shape []) [Shape]
shapes then (VName, Stms lore) -> m (VName, Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (VName
arr, Stms lore
forall a. Monoid a => a
mempty)
            else do
            let ([Shape]
outer_shapes, Shape
inner_shape) = [Shape] -> Type -> ([Shape], Shape)
repeatShapes [Shape]
shapes Type
arr_t
                arr_t' :: Type
arr_t' = [Shape] -> Shape -> Type -> Type
repeatDims [Shape]
outer_shapes Shape
inner_shape Type
arr_t
            VName
arr' <- String -> m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> m VName) -> String -> m VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString VName
arr
            (VName, Stms lore) -> m (VName, Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (VName
arr', Stm lore -> Stms lore
forall lore. Stm lore -> Stms lore
oneStm (Stm lore -> Stms lore) -> Stm lore -> Stms lore
forall a b. (a -> b) -> a -> b
$ Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let ([PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [] [VName -> Type -> PatElemT Type
forall attr. VName -> attr -> PatElemT attr
PatElem VName
arr' Type
arr_t']) (() -> StmAux ()
forall attr. attr -> StmAux attr
defAux ()) (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$
                          BasicOp lore -> Exp lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> Exp lore) -> BasicOp lore -> Exp lore
forall a b. (a -> b) -> a -> b
$ [Shape] -> Shape -> VName -> BasicOp lore
forall lore. [Shape] -> Shape -> VName -> BasicOp lore
Repeat [Shape]
outer_shapes Shape
inner_shape VName
arr)

        determineRepeats :: VName -> TypeBase shape u -> [LoopNesting] -> [Shape]
determineRepeats VName
arr TypeBase shape u
arr_t [LoopNesting]
nests
          | ([LoopNesting]
skipped, LoopNesting
arr_nest:[LoopNesting]
nests') <- (LoopNesting -> Bool)
-> [LoopNesting] -> ([LoopNesting], [LoopNesting])
forall a. (a -> Bool) -> [a] -> ([a], [a])
break (VName -> LoopNesting -> Bool
hasInput VName
arr) [LoopNesting]
nests,
            [(Param Type
arr_p, VName
_)] <- LoopNesting -> [(Param Type, VName)]
loopNestingParamsAndArrs LoopNesting
arr_nest =
              Result -> Shape
forall d. [d] -> ShapeBase d
Shape ((LoopNesting -> SubExp) -> [LoopNesting] -> Result
forall a b. (a -> b) -> [a] -> [b]
map LoopNesting -> SubExp
loopNestingWidth [LoopNesting]
skipped) Shape -> [Shape] -> [Shape]
forall a. a -> [a] -> [a]
:
              VName -> TypeBase shape u -> [LoopNesting] -> [Shape]
determineRepeats (Param Type -> VName
forall attr. Param attr -> VName
paramName Param Type
arr_p) (TypeBase shape u -> TypeBase shape u
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType TypeBase shape u
arr_t) [LoopNesting]
nests'
          | Bool
otherwise =
              Result -> Shape
forall d. [d] -> ShapeBase d
Shape ((LoopNesting -> SubExp) -> [LoopNesting] -> Result
forall a b. (a -> b) -> [a] -> [b]
map LoopNesting -> SubExp
loopNestingWidth [LoopNesting]
nests) Shape -> [Shape] -> [Shape]
forall a. a -> [a] -> [a]
: Int -> Shape -> [Shape]
forall a. Int -> a -> [a]
replicate (TypeBase shape u -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank TypeBase shape u
arr_t) (Result -> Shape
forall d. [d] -> ShapeBase d
Shape [])

        hasInput :: VName -> LoopNesting -> Bool
hasInput VName
arr LoopNesting
nest
          | [(Param Type
_, VName
arr')] <- LoopNesting -> [(Param Type, VName)]
loopNestingParamsAndArrs LoopNesting
nest, VName
arr' VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
arr = Bool
True
          | Bool
otherwise = Bool
False


-- | Attempt to turn the accumulated statements into a kernel.  If
-- 'distributeIfPossible' succeeds, its updated accumulator is returned;
-- otherwise the accumulator is handed back unchanged.
distribute :: MonadFreshNames m => DistAcc -> DistNestT m DistAcc
distribute acc = do
  distributed <- distributeIfPossible acc
  return $ fromMaybe acc distributed

-- | Lift the segment-level constructor stored in the environment
-- ('distSegLevel', whose binder runs over the underlying monad) to one
-- whose binder runs over 'DistNestT'.  The lifted constructor runs the
-- original in the current scope, emits the statements it produced into
-- the surrounding binder, and returns the chosen 'SegLevel'.
mkSegLevel :: MonadFreshNames m => DistNestT m (MkSegLevel (DistNestT m))
mkSegLevel = do
  base_mk_lvl <- asks distSegLevel
  pure $ \w desc r -> do
    -- Two lifts: out of the underlying monad into DistNestT, then into
    -- the BinderT we are currently building in.
    (lvl, lvl_stms) <-
      lift . lift . runBinderT (base_mk_lvl w desc r) =<< askScope
    addStms lvl_stms
    pure lvl

-- | Try to distribute the accumulated statements over the current
-- nesting with 'tryDistribute'.  On success the resulting kernel
-- statements are recorded with 'addKernel', and a fresh accumulator is
-- returned that holds the new targets and no pending statements.
-- Returns 'Nothing' (and records nothing) if distribution fails.
distributeIfPossible :: MonadFreshNames m => DistAcc -> DistNestT m (Maybe DistAcc)
distributeIfPossible acc = do
  nest <- asks distNest
  mk_lvl <- mkSegLevel
  attempt <- tryDistribute mk_lvl nest (distTargets acc) (distStms acc)
  -- Traversing the Maybe runs the body only in the Just case.
  forM attempt $ \(targets, kernel) -> do
    addKernel kernel
    return DistAcc { distTargets = targets
                   , distStms = mempty
                   }

-- | First distribute the accumulated statements with 'tryDistribute',
-- then try to distribute the single statement @bnd@ with
-- 'tryDistributeStm' against the resulting targets.
--
-- Unlike 'distributeIfPossible', the distributed statements are not
-- added here.  They are returned, wrapped as 'PostKernels', together
-- with:
--
-- * the 'Result' from 'tryDistributeStm';
-- * the new 'KernelNest';
-- * a fresh accumulator holding the new targets and no statements.
--
-- Returns 'Nothing' if either step fails.
distributeSingleStm :: MonadFreshNames m =>
                       DistAcc -> Stm
                    -> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
distributeSingleStm acc bnd = runMaybeT $ do
  nest <- lift $ asks distNest
  mk_lvl <- lift mkSegLevel
  (targets, distributed_bnds) <-
    MaybeT $ tryDistribute mk_lvl nest (distTargets acc) (distStms acc)
  (res, targets', new_kernel_nest) <-
    MaybeT $ tryDistributeStm nest targets bnd
  let acc' = DistAcc { distTargets = targets'
                     , distStms = mempty
                     }
  return (PostKernels [PostKernel distributed_bnds], res, new_kernel_nest, acc')

-- | Build a flat kernel for a scatter nested inside the map nest @nest@.
--
-- The scatter is treated as one extra, innermost nesting level.  Each
-- kernel thread evaluates the scatter lambda's body and produces
-- 'WriteReturns' results that write directly into the destination
-- arrays.  The final pattern is the outermost nesting's pattern with
-- its value elements rearranged by @perm@.  All generated statements
-- are renamed before being returned.
segmentedScatterKernel :: MonadFreshNames m =>
                          KernelNest
                       -> [Int]
                       -> Pattern
                       -> Certificates
                       -> SubExp
                       -> Out.Lambda Out.Kernels
                       -> [VName] -> [(SubExp,Int,VName)]
                       -> DistNestT m KernelsStms
segmentedScatterKernel nest perm scatter_pat cs scatter_w lam ivs dests = do
  -- We replicate some of the checking done by 'isSegmentedOp', but
  -- things are different because a scatter is not a reduction or
  -- scan.
  --
  -- First, pretend that the scatter is also part of the nesting.  The
  -- KernelNest we produce here is technically not sensible, but it's
  -- good enough for flatKernel to work.
  let lam_body = lambdaBody lam
      scatter_nesting =
        MapNesting scatter_pat cs scatter_w $ zip (lambdaParams lam) ivs
      nest' = pushInnerKernelNesting (scatter_pat, bodyResult lam_body)
              scatter_nesting nest
  (ispace, kernel_inps) <- flatKernel nest'

  let (as_ws, as_ns, as) = unzip3 dests
      num_indices = sum as_ns

  -- The input/output arrays ('as') _must_ correspond to some kernel
  -- input, or else the original nested scatter would have been
  -- ill-typed.  Find them.
  as_inps <- forM as $ \a ->
    case find ((== a) . kernelInputName) kernel_inps of
      Just inp -> return inp
      Nothing -> error "Ill-typed nested scatter encountered."

  mk_lvl <- mkSegLevel

  -- The lambda returns all indices first, then all values.  Group the
  -- value types per destination (by 'as_ns') and keep one type per
  -- destination: the type of its first value.
  let value_ts = drop num_indices $ lambdaReturnType lam
      rts = concatMap (take 1) $ chunks as_ns value_ts
      (is, vs) = splitAt num_indices $ bodyResult lam_body

  -- Maybe add certificates to the indices.
  (is', k_body_stms) <- runBinder $ do
    addStms $ bodyStms lam_body
    if cs == mempty
      then return is
      else forM is $ \i ->
        certifying cs $ letSubExp "scatter_i" $ BasicOp $ SubExp i

  let returns = zipWith3 (inPlaceReturn ispace) as_ws as_inps $
                chunks as_ns $ zip is' vs
      k_body = KernelBody () k_body_stms returns

  (k, k_bnds) <- mapKernel mk_lvl ispace kernel_inps rts k_body

  let result_pat = Pattern [] $ rearrangeShape perm $
                   patternValueElements $ loopNestingPattern $ fst nest

  stms <- runBinder_ $ do
    addStms k_bnds
    letBind_ result_pat $ Op $ SegOp k
  traverse renameStm stms

  -- Build the 'WriteReturns' for one destination.  The destination's
  -- shape is the outer kernel widths (all but the innermost) followed
  -- by its own width @aw@.  Each write is indexed by the outer thread
  -- ids followed by the scatter index.
  where inPlaceReturn ispace aw inp is_vs =
          WriteReturns (init ws ++ [aw]) (kernelInputArray inp)
          [ (map Var (init gtids) ++ [i], v) | (i, v) <- is_vs ]
          where (gtids, ws) = unzip ispace

segmentedUpdateKernel :: MonadFreshNames m =>
                         KernelNest
                      -> [Int]
                      -> Certificates
                      -> VName
                      -> Slice SubExp
                      -> VName
                      -> DistNestT m KernelsStms
segmentedUpdateKernel :: KernelNest
-> [Int]
-> Certificates
-> VName
-> Slice SubExp
-> VName
-> DistNestT m (Stms Kernels)
segmentedUpdateKernel KernelNest
nest [Int]
perm Certificates
cs VName
arr Slice SubExp
slice VName
v = do
  ([(VName, SubExp)]
base_ispace, [KernelInput]
kernel_inps) <- KernelNest -> DistNestT m ([(VName, SubExp)], [KernelInput])
forall (m :: * -> *).
MonadFreshNames m =>
KernelNest -> m ([(VName, SubExp)], [KernelInput])
flatKernel KernelNest
nest
  let slice_dims :: Result
slice_dims = Slice SubExp -> Result
forall d. Slice d -> [d]
sliceDims Slice SubExp
slice
  [VName]
slice_gtids <- Int -> DistNestT m VName -> DistNestT m [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
slice_dims) (String -> DistNestT m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"gtid_slice")

  let ispace :: [(VName, SubExp)]
ispace = [(VName, SubExp)]
base_ispace [(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [VName] -> Result -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
slice_gtids Result
slice_dims

  ((Type
res_t, KernelResult
res), Stms Kernels
kstms) <- Binder Kernels (Type, KernelResult)
-> DistNestT m ((Type, KernelResult), Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore a -> m (a, Stms lore)
runBinder (Binder Kernels (Type, KernelResult)
 -> DistNestT m ((Type, KernelResult), Stms Kernels))
-> Binder Kernels (Type, KernelResult)
-> DistNestT m ((Type, KernelResult), Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do

    -- Compute indexes into full array.
    SubExp
v' <- Certificates
-> BinderT Kernels (State VNameSource) SubExp
-> BinderT Kernels (State VNameSource) SubExp
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs (BinderT Kernels (State VNameSource) SubExp
 -> BinderT Kernels (State VNameSource) SubExp)
-> BinderT Kernels (State VNameSource) SubExp
-> BinderT Kernels (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$
          String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"v" (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) SubExp)
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp Kernels -> Exp Kernels
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp Kernels -> Exp Kernels) -> BasicOp Kernels -> Exp Kernels
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp Kernels
forall lore. VName -> Slice SubExp -> BasicOp lore
Index VName
v (Slice SubExp -> BasicOp Kernels)
-> Slice SubExp -> BasicOp Kernels
forall a b. (a -> b) -> a -> b
$ (VName -> DimIndex SubExp) -> [VName] -> Slice SubExp
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> (VName -> SubExp) -> VName -> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName]
slice_gtids
    let pexp :: SubExp -> PrimExp VName
pexp = PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
int32
    Result
slice_is <- (PrimExp VName -> BinderT Kernels (State VNameSource) SubExp)
-> [PrimExp VName] -> Binder Kernels Result
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (String
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index" (Exp Kernels -> BinderT Kernels (State VNameSource) SubExp)
-> (PrimExp VName
    -> BinderT Kernels (State VNameSource) (Exp Kernels))
-> PrimExp VName
-> BinderT Kernels (State VNameSource) SubExp
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< PrimExp VName -> BinderT Kernels (State VNameSource) (Exp Kernels)
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp) ([PrimExp VName] -> Binder Kernels Result)
-> [PrimExp VName] -> Binder Kernels Result
forall a b. (a -> b) -> a -> b
$
                Slice (PrimExp VName) -> [PrimExp VName] -> [PrimExp VName]
forall d. Num d => Slice d -> [d] -> [d]
fixSlice ((DimIndex SubExp -> DimIndex (PrimExp VName))
-> Slice SubExp -> Slice (PrimExp VName)
forall a b. (a -> b) -> [a] -> [b]
map ((SubExp -> PrimExp VName)
-> DimIndex SubExp -> DimIndex (PrimExp VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> PrimExp VName
pexp) Slice SubExp
slice) ([PrimExp VName] -> [PrimExp VName])
-> [PrimExp VName] -> [PrimExp VName]
forall a b. (a -> b) -> a -> b
$ (VName -> PrimExp VName) -> [VName] -> [PrimExp VName]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> PrimExp VName
pexp (SubExp -> PrimExp VName)
-> (VName -> SubExp) -> VName -> PrimExp VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName]
slice_gtids

    let write_is :: Result
write_is = ((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> ((VName, SubExp) -> VName) -> (VName, SubExp) -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst) [(VName, SubExp)]
base_ispace Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
slice_is
        arr' :: VName
arr' = VName -> (KernelInput -> VName) -> Maybe KernelInput -> VName
forall b a. b -> (a -> b) -> Maybe a -> b
maybe (String -> VName
forall a. HasCallStack => String -> a
error String
"incorrectly typed Update") KernelInput -> VName
kernelInputArray (Maybe KernelInput -> VName) -> Maybe KernelInput -> VName
forall a b. (a -> b) -> a -> b
$
               (KernelInput -> Bool) -> [KernelInput] -> Maybe KernelInput
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
==VName
arr) (VName -> Bool) -> (KernelInput -> VName) -> KernelInput -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelInput -> VName
kernelInputName) [KernelInput]
kernel_inps
    Type
arr_t <- VName -> BinderT Kernels (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr'
    Type
v_t <- SubExp -> BinderT Kernels (State VNameSource) Type
forall t (m :: * -> *). HasScope t m => SubExp -> m Type
subExpType SubExp
v'
    (Type, KernelResult) -> Binder Kernels (Type, KernelResult)
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
v_t,
            Result -> VName -> [(Result, SubExp)] -> KernelResult
WriteReturns (Type -> Result
forall u. TypeBase Shape u -> Result
arrayDims Type
arr_t) VName
arr' [(Result
write_is, SubExp
v')])

  MkSegLevel (DistNestT m)
mk_lvl <- DistNestT m (MkSegLevel (DistNestT m))
forall (m :: * -> *).
MonadFreshNames m =>
DistNestT m (MkSegLevel (DistNestT m))
mkSegLevel
  (SegOp Kernels
k, Stms Kernels
prestms) <- MkSegLevel (DistNestT m)
-> [(VName, SubExp)]
-> [KernelInput]
-> [Type]
-> KernelBody Kernels
-> DistNestT m (SegOp Kernels, Stms Kernels)
forall (m :: * -> *).
(HasScope Kernels m, MonadFreshNames m) =>
MkSegLevel m
-> [(VName, SubExp)]
-> [KernelInput]
-> [Type]
-> KernelBody Kernels
-> m (SegOp Kernels, Stms Kernels)
mapKernel MkSegLevel (DistNestT m)
mk_lvl [(VName, SubExp)]
ispace [KernelInput]
kernel_inps [Type
res_t] (KernelBody Kernels -> DistNestT m (SegOp Kernels, Stms Kernels))
-> KernelBody Kernels -> DistNestT m (SegOp Kernels, Stms Kernels)
forall a b. (a -> b) -> a -> b
$
                  BodyAttr Kernels
-> Stms Kernels -> [KernelResult] -> KernelBody Kernels
forall lore.
BodyAttr lore -> Stms lore -> [KernelResult] -> KernelBody lore
KernelBody () Stms Kernels
kstms [KernelResult
res]

  (Stm Kernels -> DistNestT m (Stm Kernels))
-> Stms Kernels -> DistNestT m (Stms Kernels)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse Stm Kernels -> DistNestT m (Stm Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Stm lore -> m (Stm lore)
renameStm (Stms Kernels -> DistNestT m (Stms Kernels))
-> (BinderT Kernels (State VNameSource) ()
    -> DistNestT m (Stms Kernels))
-> BinderT Kernels (State VNameSource) ()
-> DistNestT m (Stms Kernels)
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< BinderT Kernels (State VNameSource) ()
-> DistNestT m (Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore a -> m (Stms lore)
runBinder_ (BinderT Kernels (State VNameSource) ()
 -> DistNestT m (Stms Kernels))
-> BinderT Kernels (State VNameSource) ()
-> DistNestT m (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
    Stms (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms Stms (Lore (BinderT Kernels (State VNameSource)))
Stms Kernels
prestms

    let pat :: PatternT Type
pat = [PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [] ([PatElemT Type] -> PatternT Type)
-> [PatElemT Type] -> PatternT Type
forall a b. (a -> b) -> a -> b
$ [Int] -> [PatElemT Type] -> [PatElemT Type]
forall a. [Int] -> [a] -> [a]
rearrangeShape [Int]
perm ([PatElemT Type] -> [PatElemT Type])
-> [PatElemT Type] -> [PatElemT Type]
forall a b. (a -> b) -> a -> b
$
              PatternT Type -> [PatElemT Type]
forall attr. PatternT attr -> [PatElemT attr]
patternValueElements (PatternT Type -> [PatElemT Type])
-> PatternT Type -> [PatElemT Type]
forall a b. (a -> b) -> a -> b
$ LoopNesting -> Pattern Kernels
loopNestingPattern (LoopNesting -> Pattern Kernels) -> LoopNesting -> Pattern Kernels
forall a b. (a -> b) -> a -> b
$ KernelNest -> LoopNesting
forall a b. (a, b) -> a
fst KernelNest
nest

    Pattern (Lore (BinderT Kernels (State VNameSource)))
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ PatternT Type
Pattern (Lore (BinderT Kernels (State VNameSource)))
pat (Exp (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) ())
-> Exp (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Op Kernels -> Exp Kernels
forall lore. Op lore -> ExpT lore
Op (Op Kernels -> Exp Kernels) -> Op Kernels -> Exp Kernels
forall a b. (a -> b) -> a -> b
$ SegOp Kernels -> HostOp Kernels (SOAC Kernels)
forall lore op. SegOp lore -> HostOp lore op
SegOp SegOp Kernels
k

-- | Distribute a Hist that sits inside the map nest @nest@ as a single
-- flat histogram kernel (built by 'histKernel').
--
-- The destination arrays of every histogram operator are replaced by
-- the arrays of the kernel inputs with the same name; if a destination
-- has no matching kernel input, 'error' is called.  The result pattern
-- is the value part of the pattern of the first 'LoopNesting' in
-- @nest@, with its elements rearranged by @perm@.
segmentedHistKernel :: MonadFreshNames m =>
                       KernelNest
                    -> [Int]
                    -> Certificates
                    -> SubExp
                    -> [SOAC.HistOp SOACS]
                    -> Out.Lambda Out.Kernels
                    -> [VName]
                    -> DistNestT m KernelsStms
segmentedHistKernel nest perm cs hist_w ops lam arrs = do
  -- We replicate some of the checking done by 'isSegmentedOp', but
  -- things are different because a Hist is not a reduction or
  -- scan.
  (ispace, inputs) <- flatKernel nest
  let orig_pat = Pattern [] $ rearrangeShape perm $
                 patternValueElements $ loopNestingPattern $ fst nest

  -- The input/output arrays _must_ correspond to some kernel input,
  -- or else the original nested Hist would have been ill-typed.
  -- Find them.
  ops' <- forM ops $ \(SOAC.HistOp num_bins rf dests nes op) -> do
    dests' <- mapM (fmap kernelInputArray . findInput inputs) dests
    pure $ SOAC.HistOp num_bins rf dests' nes op

  mk_lvl <- asks distSegLevel
  scope <- askScope
  lift $ flip runBinderT_ scope $ do
    -- It is important not to launch unnecessarily many threads for
    -- histograms, because it may mean we unnecessarily need to reduce
    -- subhistograms as well.
    lvl <- mk_lvl (hist_w : map snd ispace) "seghist" $
           NoRecommendation SegNoVirt
    addStms =<< histKernel lvl orig_pat ispace inputs cs hist_w ops' lam arrs
  where -- Find the kernel input named @a@; its absence means the
        -- nested Hist was ill-typed.
        findInput kernel_inps a =
          maybe bad pure $ find ((== a) . kernelInputName) kernel_inps
        bad = error "Ill-typed nested Hist encountered."

-- | Generate the statements for a segmented histogram using 'segHist'.
--
-- Each SOACS histogram operator is converted with 'determineReduceOp'.
-- Kernel inputs whose array is one of the histogram destinations are
-- left out of the ordinary inputs passed to 'segHist'.  The generated
-- statements are renamed with 'renameStm' and added under
-- @'certifying' cs@.
histKernel :: (MonadFreshNames m, HasScope Out.Kernels m) =>
              SegLevel -> Pattern -> [(VName, SubExp)] -> [KernelInput]
           -> Certificates -> SubExp -> [SOAC.HistOp SOACS]
           -> Out.Lambda Out.Kernels -> [VName]
           -> m KernelsStms
histKernel lvl orig_pat ispace inputs cs hist_w ops lam arrs = runBinder_ $ do
  ops' <- forM ops $ \(SOAC.HistOp num_bins rf dests nes op) -> do
    (op', nes', shape) <- determineReduceOp op nes
    return $ Out.HistOp num_bins rf dests nes' shape op'

  let dest_arrs = concatMap Out.histDest ops'
      -- Presumably dropped because the destinations are already handed
      -- to 'segHist' through the operators built above.
      inputs' = filter ((`notElem` dest_arrs) . kernelInputArray) inputs

  certifying cs $
    addStms =<< traverse renameStm =<<
    segHist lvl orig_pat hist_w ispace inputs' ops' lam arrs

-- | Convert a SOACS operator and its neutral elements into the form
-- used by a Kernels histogram operator.
--
-- If every neutral element is a variable:
--
-- * nested vector maps are peeled off the operator with 'isVectorMap';
-- * each neutral element is indexed at 0 along every peeled dimension,
--   which adds a @hist_ne@ statement to the current binder.
--
-- The returned 'Shape' is the peeled map shape.  Otherwise the operator
-- is converted unchanged, the neutral elements are returned as given,
-- and the shape is empty.
determineReduceOp :: (MonadBinder m, Lore m ~ Out.Kernels) =>
                     Lambda -> [SubExp] -> m (Out.Lambda Out.Kernels, [SubExp], Shape)
determineReduceOp lam nes =
  -- FIXME? We are assuming that the accumulator is a replicate, and
  -- we fish out its value in a gross way.
  case mapM subExpVar nes of
    Just ne_vs' -> do
      let (shape, lam') = isVectorMap lam
          -- Index 0 in each dimension of the peeled vector map; the
          -- same list is used for every neutral element.
          zero_is = replicate (shapeRank shape) $ DimFix $ intConst Int32 0
          fishNeutral ne_v = do
            ne_v_t <- lookupType ne_v
            letSubExp "hist_ne" $ BasicOp $ Index ne_v $ fullSlice ne_v_t zero_is
      nes' <- mapM fishNeutral ne_vs'
      return (soacsLambdaToKernels lam', nes', shape)
    Nothing ->
      return (soacsLambdaToKernels lam, nes, mempty)

-- | Peel nested "vector maps" off an operator.
--
-- A layer is peeled when all of the following hold for the body of
-- @lam@:
--
-- * it consists of exactly one statement;
-- * that statement is a 'Screma' that 'isMapSOAC' recognises as a map;
-- * the map is over exactly the parameters of @lam@, in order;
-- * it binds a pattern without context elements;
-- * the pattern's names are the body result.
--
-- Peeling then recurses into the map's lambda, prepending the map width
-- to the returned 'Shape'.  Otherwise the result is @(mempty, lam)@.
isVectorMap :: Lambda -> (Shape, Lambda)
isVectorMap lam =
  case stmsToList $ bodyStms body of
    [Let (Pattern [] pes) _ (Op (Screma w form arrs))]
      | bodyResult body == map (Var . patElemName) pes,
        Just map_lam <- isMapSOAC form,
        arrs == map paramName (lambdaParams lam) ->
          let (shape, lam') = isVectorMap map_lam
          in (Shape [w] <> shape, lam')
    _ -> (mempty, lam)
  where body = lambdaBody lam

-- | Try to distribute a scanomap nested in @nest@ as a segmented scan.
--
-- Segmentation is delegated to 'isSegmentedOp', and the 'Maybe' result
-- is whatever that returns.  Presumably it is 'Nothing' when the nest
-- cannot be treated as a segmented operation; confirm against
-- 'isSegmentedOp'.  In the callback, the scan is built with 'segScan',
-- and its statements are renamed before being added.
segmentedScanomapKernel :: MonadFreshNames m =>
                           KernelNest
                        -> [Int]
                        -> SubExp
                        -> Out.Lambda Out.Kernels -> Out.Lambda Out.Kernels
                        -> [SubExp] -> [VName]
                        -> DistNestT m (Maybe KernelsStms)
segmentedScanomapKernel nest perm segment_size lam map_lam nes arrs = do
  mk_lvl <- asks distSegLevel
  isSegmentedOp nest perm segment_size (freeIn lam) (freeIn map_lam) nes arrs $
    \pat ispace inps nes' _ _ -> do
      lvl <- mk_lvl (segment_size : map snd ispace) "segscan" $
             NoRecommendation SegNoVirt
      addStms =<< traverse renameStm =<<
        segScan lvl pat segment_size lam map_lam nes' arrs ispace inps

-- | Attempt to express a redomap that sits inside the kernel nest @nest@ as
-- a single segmented reduction.
--
-- Legality checking and input preparation are delegated to 'isSegmentedOp'.
-- If that gives up, the result is 'Nothing'. Otherwise:
--
-- * a 'SegRedOp' is built from @comm@, the reduction lambda @lam@ and the
--   prepared neutral elements;
-- * a segment level is requested from the environment's 'distSegLevel' for
--   the dimensions @segment_size : map snd ispace@;
-- * the statements produced by 'segRed' (with @map_lam@ as the map function
--   and the original @arrs@) are renamed and added.
regularSegmentedRedomapKernel :: MonadFreshNames m =>
                                 KernelNest
                              -> [Int]
                              -> SubExp -> Commutativity
                              -> Out.Lambda Out.Kernels -> Out.Lambda Out.Kernels
                              -> [SubExp] -> [VName]
                              -> DistNestT m (Maybe KernelsStms)
regularSegmentedRedomapKernel nest perm segment_size comm lam map_lam nes arrs = do
  mkLevel <- asks distSegLevel
  -- The two array arguments supplied by 'isSegmentedOp' are not needed here;
  -- 'segRed' is given the original @arrs@.
  let buildSegRed pat ispace inps nes' _ _ = do
        let dims = segment_size : map snd ispace
            red_op = SegRedOp comm lam nes' mempty
        lvl <- mkLevel dims "segred" $ NoRecommendation SegNoVirt
        stms <- segRed lvl pat segment_size [red_op] map_lam arrs ispace inps
        traverse renameStm stms >>= addStms
  isSegmentedOp nest perm segment_size (freeIn lam) (freeIn map_lam) nes arrs buildSegRed

-- | Try to turn an operation that sits inside the kernel nest @nest@ into a
-- segmented operation over a flattened index space.
--
-- The work runs in 'MaybeT', so any failed legality check makes the whole
-- result 'Nothing'. When every check passes, the continuation @m@ is run
-- inside a fresh 'BinderT' seeded with the current scope, and the statements
-- it produces are returned. @m@ receives, in order:
--
--   * the value elements of the outermost nesting's pattern, rearranged by
--     @perm@;
--   * the index space and kernel inputs computed by 'flatKernel';
--   * the neutral elements @nes@, each checked not to be bound by the nest;
--   * one array per element of @arrs@, as obtained for the nest;
--   * those same arrays passed through 'reshapeOuter' so that their outer
--     @2 + length (snd nest)@ dimensions become one dimension of size
--     @segment_size * product (map snd ispace)@.
--
-- NOTE(review): the fifth argument (names free in the fold/map lambda) is
-- currently ignored; only @free_in_op@ is checked against the nest.
isSegmentedOp :: MonadFreshNames m =>
                 KernelNest
              -> [Int]
              -> SubExp
              -> Names -> Names
              -> [SubExp] -> [VName]
              -> (Pattern
                  -> [(VName, SubExp)]
                  -> [KernelInput]
                  -> [SubExp] -> [VName]  -> [VName]
                  -> BinderT Out.Kernels m ())
              -> DistNestT m (Maybe KernelsStms)
isSegmentedOp nest perm segment_size free_in_op _free_in_fold_op nes arrs m = runMaybeT $ do
  -- Array inputs to the operation must either be inputs to the outermost
  -- loop nesting or be free in the loop nest. Nothing free in the op may be
  -- bound by the nest, and the neutral elements must be free in the loop
  -- nest as well.
  --
  -- Ideally, names from free_in_op that are bound in the nest would be
  -- summarised and described in terms of the segment indices. For now, such
  -- operations are simply rejected by the check below.
  let nest_bound = boundInKernelNest nest
  (ispace, kernel_inps) <- flatKernel nest

  when (free_in_op `namesIntersect` nest_bound) $
    fail "Non-fold lambda uses nest-bound parameters."

  let index_vars = map (Var . fst) ispace
      widths = map snd ispace
      boundHere v = v `nameIn` nest_bound

      prepareNe ne@(Var v)
        | boundHere v = fail "Neutral element bound in nest"
        | otherwise = return ne
      prepareNe ne = return ne

      -- Each array yields a (not yet executed) binder action that produces
      -- the array to hand to the segmented operation.
      prepareArr arr =
        case find ((== arr) . kernelInputName) kernel_inps of
          Just inp
            -- Indexed exactly by the whole index space: use it directly.
            | kernelInputIndices inp == index_vars ->
                return $ return $ kernelInputArray inp
            -- Otherwise repeat it along the index-space dimensions that its
            -- indices skip.
            | not $ boundHere $ kernelInputArray inp ->
                return $ replicateMissing ispace inp
          Nothing
            -- Free inside the loop nesting, so it must be replicated across
            -- the whole index space.
            | not $ boundHere arr ->
                return $ letExp (baseString arr ++ "_repd") $
                  BasicOp $ Replicate (Shape widths) $ Var arr
          _ -> fail "Input not free or outermost."

  nes' <- traverse prepareNe nes
  arr_makers <- traverse prepareArr arrs
  scope <- lift askScope

  lift $ lift $ (`runBinderT_` scope) $ do
    -- Every input must end up with size segment_size*nesting_size.
    total_num_elements <-
      letSubExp "total_num_elements" =<< foldBinOp (Mul Int32) segment_size widths

    let flatten arr = do
          arr_shape <- arrayShape <$> lookupType arr
          -- CHECKME: is the length the right thing here?  We want to
          -- reproduce the parameter type.
          let collapsed = reshapeOuter [DimNew total_num_elements]
                          (2 + length (snd nest)) arr_shape
          letExp (baseString arr ++ "_flat") $ BasicOp $ Reshape collapsed arr

    nested_arrs <- sequence arr_makers
    flat_arrs <- traverse flatten nested_arrs

    let out_pes = rearrangeShape perm $ patternValueElements $
                  loopNestingPattern $ fst nest

    m (Pattern [] out_pes) ispace kernel_inps nes' nested_arrs flat_arrs

  -- Build an array for @inp@ by 'Repeat'ing its kernel input array along the
  -- index-space dimensions that 'determineRepeats' finds its indices skip.
  where replicateMissing space inp = do
          let src = kernelInputArray inp
          t <- lookupType src
          let (outer_shapes, inner_shape) =
                repeatShapes (determineRepeats space $ kernelInputIndices inp) t
          letExp "repeated" $ BasicOp $ Repeat outer_shapes inner_shape src

        -- For each index, the index-space dimensions skipped before reaching
        -- it; the final shape holds the dimensions after the last index.
        determineRepeats space (i:is) =
          let (skipped, rest) = span ((/= i) . Var . fst) space
          in Shape (map snd skipped) : determineRepeats (drop 1 rest) is
        determineRepeats space [] =
          [Shape $ map snd space]

-- | Relate the value elements of @pat@ to the result list @res@.
--
-- Pattern elements whose names are not free in @res@ are considered
-- "missing" and are appended (as variables) to @res@. The function then asks
-- 'isPermutationOf' whether the pattern's value elements are a permutation of
-- that extended list. On success it returns the permutation together with the
-- missing pattern elements; otherwise it returns 'Nothing'.
permutationAndMissing :: Pattern -> [SubExp] -> Maybe ([Int], [PatElem])
permutationAndMissing pat res =
  fmap (\perm -> (perm, missing)) $
    map asVar pes `isPermutationOf` (res ++ map asVar missing)
  where pes = patternValueElements pat
        res_names = freeIn res
        (_used, missing) = partition ((`nameIn` res_names) . patElemName) pes
        asVar = Var . patElemName

-- | Add extra pattern elements to every kernel nesting level.
--
-- Each level receives its own copy of @pes@. The copies get fresh names that
-- keep the original base names. Their types are extended (via
-- 'arrayOfShape') with the widths of that level and of every level nested
-- inside it, so the outermost level gets all widths. The copies are appended
-- after the level's existing pattern elements.
--
-- NOTE(review): the rebuilt pattern is @Pattern [] (patternElements ...)@,
-- so any existing context elements end up in the value part. Confirm that
-- nesting patterns never carry context.
expandKernelNest :: MonadFreshNames m =>
                    [PatElem] -> KernelNest -> m KernelNest
expandKernelNest pes (outer_nest, inner_nests) = do
  let widths = map loopNestingWidth (outer_nest : inner_nests)
  outer_nest' <- expandWith outer_nest widths
  -- Inner level k gets the widths of levels k and deeper.
  inner_nests' <- zipWithM expandWith inner_nests $ tails $ drop 1 widths
  return (outer_nest', inner_nests')
  where expandWith nest dims = do
          extra <- traverse (expandPatElemWith dims) pes
          let old = patternElements $ loopNestingPattern nest
          return nest { loopNestingPattern = Pattern [] $ old <> extra }

        expandPatElemWith dims pe = do
          name <- newVName $ baseString $ patElemName pe
          return pe { patElemName = name
                    , patElemAttr = patElemType pe `arrayOfShape` Shape dims
                    }

-- | Handle a statement according to whether a kernel version of it exists.
--
-- * 'Nothing': the statement @bnd@, certified with @cs@, is added to the
--   accumulator @acc@ via 'addStmToKernel'.
-- * @'Just' bnds@: the post-kernels @kernels@ are added first. Then @bnds@,
--   each certified with @cs@, are added as a kernel. The alternative
--   accumulator @acc'@ is returned as-is.
kernelOrNot :: MonadFreshNames m =>
               Certificates -> Stm -> DistAcc
            -> PostKernels -> DistAcc -> Maybe KernelsStms
            -> DistNestT m DistAcc
kernelOrNot cs bnd acc kernels acc' kernel_stms =
  case kernel_stms of
    Nothing ->
      addStmToKernel (certify cs bnd) acc
    Just bnds -> do
      addKernels kernels
      addKernel $ certify cs <$> bnds
      return acc'

-- | Distribute a 'MapLoop'.
--
-- The steps are:
--
-- 1. Push a new inner target onto the accumulator's targets: the map's
--    pattern paired with the lambda body's result.
-- 2. Distribute the lambda body's statements inside a 'mapNesting' for this
--    map, then distribute the result.
-- 3. Pass the resulting accumulator through 'leavingNesting' and distribute
--    it once more.
distributeMap :: MonadFreshNames m => MapLoop -> DistAcc -> DistNestT m DistAcc
distributeMap maploop@(MapLoop pat cs w lam arrs) acc = do
  inner_acc <- mapNesting pat cs w lam arrs $
               distributeMapBodyStms body_acc (bodyStms body) >>= distribute
  leavingNesting maploop inner_acc >>= distribute
  where body = lambdaBody lam
        body_acc = DistAcc { distTargets = pushInnerTarget (pat, bodyResult body) $
                                           distTargets acc
                           , distStms = mempty
                           }